diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py new file mode 100644 index 00000000000..1bb8f4e09dd --- /dev/null +++ b/tests/kernels/moe/test_batched_moe.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass + +import pytest +import torch +import triton.language as tl + +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + invoke_batched_silu_and_mul, invoke_moe_batched_triton_kernel) + + +@dataclass +class BatchedMMConfig: + dtype: torch.dtype + num_experts: int + max_tokens_per_expert: int + K: int + N: int + + +@dataclass +class BatchedMMTensors: + A: torch.Tensor # [E, max_tokens, K] + B: torch.Tensor # [E, K, N] - column major + C: torch.Tensor # [E, max_tokens, N] + num_expert_tokens: torch.Tensor # [E] + + @staticmethod + def make_tensors(config: BatchedMMConfig): + A = torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.K), + device="cuda", + dtype=config.dtype) / 50.0 + B = torch.randn((config.num_experts, config.N, config.K), + device="cuda", + dtype=config.dtype) / 50.0 + C = torch.zeros( + (config.num_experts, config.max_tokens_per_expert, config.N), + device="cuda", + dtype=config.dtype) + num_expert_tokens = torch.randint(low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts, ), + device="cuda", + dtype=torch.int32) + return BatchedMMTensors(A, B, C, num_expert_tokens) + + +def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + num_expert_tokens: torch.Tensor) -> torch.Tensor: + + num_expert_tokens_cpu = num_expert_tokens.clone() + num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") + num_experts = num_expert_tokens.size(0) + + for e in range(num_experts): + num_tokens = num_expert_tokens_cpu[e] + C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) + + return C + + +@pytest.mark.parametrize("num_experts", [16, 32]) +@pytest.mark.parametrize("max_tokens_per_expert", [512]) +@pytest.mark.parametrize("K", [256]) +@pytest.mark.parametrize("N", [512]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, + N: int, dtype: torch.dtype): + + config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) + tensors = BatchedMMTensors.make_tensors(config) + + test_output = tensors.C + ref_output = test_output.clone() + + compute_tl_dtype = { + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, + torch.float32: tl.float32 + }[test_output.dtype] + invoke_moe_batched_triton_kernel( + tensors.A, + tensors.B, + test_output, + tensors.num_expert_tokens, + compute_tl_dtype, + # Quantization data + None, + None, + None, + # Quantization schemes + False, + False, + False, + config={ + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16 + }) + + ref_output = ref_impl(tensors.A, tensors.B, ref_output, + tensors.num_expert_tokens) + #torch.cuda.synchronize() + #print (f"ref output {ref_output}") + #print (f"test output {test_output}") + + torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3) + + +@dataclass +class BatchedSiluMulConfig: + dtype: torch.dtype + num_experts: int + max_tokens_per_expert: int + D: int + + +@dataclass +class BatchedSiluMulTensors: + input: torch.Tensor + output: torch.Tensor + expert_num_tokens: torch.Tensor + + @staticmethod + def make_tensors(config: BatchedSiluMulConfig): + input = torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.D * 
2), + device="cuda", + dtype=config.dtype) / 50.0 + output = torch.zeros( + (config.num_experts, config.max_tokens_per_expert, config.D), + device="cuda", + dtype=config.dtype) + num_expert_tokens = torch.randint(low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts, ), + device="cuda", + dtype=torch.int32) + return BatchedSiluMulTensors(input, output, num_expert_tokens) + + +def ref_batched_silu_mul(output: torch.Tensor, input: torch.Tensor, + num_expert_tokens: torch.Tensor) -> torch.Tensor: + + num_expert_tokens_cpu = num_expert_tokens.clone() + num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") + num_experts = num_expert_tokens.size(0) + + for e in range(num_experts): + num_tokens = num_expert_tokens_cpu[e].item() + out_part = output[e, :num_tokens, :] + in_part = input[e, :num_tokens, :] + torch.ops._C.silu_and_mul(out_part, in_part) + + +@pytest.mark.parametrize("num_experts", [16, 32]) +@pytest.mark.parametrize("max_tokens_per_expert", [128]) +@pytest.mark.parametrize("D", [128, 256]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_batched_silu_mul(num_experts: int, max_tokens_per_expert: int, D: int, + dtype: torch.dtype): + + config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert, D) + tensors = BatchedSiluMulTensors.make_tensors(config) + + test_out = tensors.output + ref_out = torch.zeros_like(test_out) + + ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens) + + invoke_batched_silu_and_mul(test_out, tensors.input, + tensors.expert_num_tokens) + + torch.testing.assert_close(test_out, ref_out) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 975cd418a17..7d24307e353 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -30,6 +30,11 @@ (224, 3072, 1536), ] +vllm_config = VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1)) +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @dataclasses.dataclass class MOETensors: @@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] 'topk_weights': topk_weights, - 'topk_ids_': topk_ids, + 'topk_ids': topk_ids, 'ab_strides1': moe_tensors.ab_strides1, 'c_strides1': moe_tensors.c_strides1, 'ab_strides2': moe_tensors.ab_strides2, @@ -231,10 +236,7 @@ def test_cutlass_moe_8_bit_no_graph( per_out_ch: bool, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) @@ -276,10 +278,7 @@ def test_cutlass_moe_8_bit_cuda_graph( per_out_ch: bool, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): dtype = torch.half mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, @@ -334,10 +333,7 @@ def test_cutlass_moe_8_bit_EP( ep_size: int, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, 
per_out_channel) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 425f36984a3..6822b88300e 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -13,6 +13,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, torch_moe_single) +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( @@ -29,6 +30,10 @@ EP_SIZE = [1, 4] TOP_KS = [2, 6] +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @@ -67,31 +72,33 @@ def test_fused_moe( else: e_map = None - torch_output = torch_moe(a, w1, w2, score, topk, e_map) - iterative_output = iterative_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) - - # Pad the weight if moe padding is enabled - if padding: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w1, w2, score, topk, e_map) + iterative_output = iterative_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) + + # Pad the weight if moe padding is enabled + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(iterative_output, torch_output, @@ -112,7 +119,7 @@ def test_fused_moe( def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ep_size: int, dtype: torch.dtype, group_size: int, has_zp: bool, weight_bits: int): - print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) + #print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -191,22 +198,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, else: e_map = None - triton_output = fused_moe(a, - w1_qweight, - w2_qweight, - score, - topk, - renormalize=False, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=e, - expert_map=e_map, - w1_scale=w1_scales, - w2_scale=w2_scales, - w1_zp=w1_qzeros if has_zp else None, - w2_zp=w2_qzeros if has_zp else None, - block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + with set_current_vllm_config(vllm_config): + triton_output = fused_moe(a, + w1_qweight, + w2_qweight, + score, + topk, + renormalize=False, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + global_num_experts=e, + expert_map=e_map, + w1_scale=w1_scales, + 
w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size]) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -422,27 +431,28 @@ def test_fused_marlin_moe( topk_weights, topk_ids = fused_topk(a, score, topk, False) - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) - - marlin_output = torch.ops.vllm.fused_marlin_moe( - a, - qweight1, - qweight2, - scales1, - scales2, - score, - topk_weights, - topk_ids, - global_num_experts=e, - expert_map=e_map, - g_idx1=g_idx1, - g_idx2=g_idx2, - sort_indices1=sort_indices1, - sort_indices2=sort_indices2, - w1_zeros=zeros1, - w2_zeros=zeros2, - num_bits=num_bits, - is_k_full=is_k_full) + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + + marlin_output = torch.ops.vllm.fused_marlin_moe( + a, + qweight1, + qweight2, + scales1, + scales2, + score, + topk_weights, + topk_ids, + global_num_experts=e, + expert_map=e_map, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + sort_indices2=sort_indices2, + w1_zeros=zeros1, + w2_zeros=zeros2, + num_bits=num_bits, + is_k_full=is_k_full) torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py new file mode 100644 index 00000000000..6e536c26870 --- /dev/null +++ b/tests/kernels/moe/test_pplx_moe.py @@ -0,0 +1,548 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the MOE layers. + +Run `pytest tests/kernels/test_pplx_moe.py`. +""" +import dataclasses +import os +import traceback +from typing import Callable, Optional + +import pytest +import torch + +try: + from pplx_kernels import AllToAll + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_finalize, nvshmem_get_unique_id, + nvshmem_init) + has_pplx = True +except ImportError: + has_pplx = False + +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec + +import vllm.model_executor.layers.fused_moe # noqa +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedExperts, BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import ( + PplxDispatchCombine) +from vllm.platforms import current_platform + +NUM_EXPERTS = [8, 64] +EP_SIZE = [1, 4] +TOP_KS = [2, 6] + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + +P = ParamSpec("P") + +requires_pplx = pytest.mark.skipif( + not has_pplx, + reason="Requires PPLX kernels", +) + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + 
torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exc() + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + args, + nprocs=world_size, + join=True, + ) + + +def parallel_launch_from_env( + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + """ + Launches a worker function in parallel across all processes in the current + environment. The environment must have the following variables set: + - WORLD_SIZE: The total number of processes. + - WORLD_LOCAL_SIZE: The number of processes on the current node. + - NODE_RANK: The rank of the current + - MASTER_ADDR: The address of the master process. + - MASTER_PORT: The port of the master process. + """ + assert not kwargs + world_size = int(os.environ["WORLD_SIZE"]) + world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) + node_rank = int(os.environ["NODE_RANK"]) + assert "MASTER_ADDR" in os.environ + assert "MASTER_PORT" in os.environ + spawn( + _worker_parallel_launch, + args=( + world_size, + world_local_size, + node_rank, + "env://", + worker, + ) + args, + nprocs=world_local_size, + join=True, + ) + + +def torch_dispatch( + a: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + max_num_tokens: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a.shape[0] + + num_tokens = a.shape[0] + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), + minlength=num_experts) + + assert tokens_per_expert.numel() == num_experts + + if max_num_tokens is None: + max_num_tokens = int(tokens_per_expert.max().item()) + + b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), + dtype=a.dtype, + device=a.device) + + token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = token_counts[expert_id] + b_a[expert_id, idx:idx + 1, :] = a[token, :] + token_counts[expert_id] = token_counts[expert_id] + 1 + + return b_a, tokens_per_expert + + +def torch_combine(b_out, topk_weight, topk_ids): + num_tokens, topk = topk_ids.shape + num_experts = b_out.shape[0] + K = b_out.shape[-1] + out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=b_out.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(expert_ids.numel()): + expert_id = expert_ids[i] + idx = expert_counts[expert_id] + out[token, :] = out[token, :] + b_out[expert_id, idx:idx + + 1, :] * topk_weight[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + 
return out + + +def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): + num_experts = w1.shape[0] + b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts) + assert b_a.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = b_a.shape + assert num_experts == b_a.shape[0] and w2.shape[1] == K + out = torch.zeros((num_experts, max_num_tokens, K), + dtype=b_a.dtype, + device=b_a.device) + tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), + dtype=b_a.dtype, + device=b_a.device) + for expert in range(num_experts): + num = tokens_per_expert[expert] + if num > 0: + torch.ops._C.silu_and_mul( + tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) + + return torch_combine(out, topk_weight, topk_ids) + + +# TODO: same as torch_moe but with fused_topk factored out. +def torch_moe2(a, w1, w2, topk_weight, topk_ids): + M, K = a.shape + topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + num_experts = w1.shape[0] + for i in range(num_experts): + mask = (topk_ids == i).view(-1) + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + triton_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) + + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + +def chunk_by_rank(t, r, w): + chunk = rank_chunk(t.shape[0], r, w) + return t[(r * chunk):(r + 1) * chunk] + + +def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): + assert torch.cuda.current_device() == pgi.local_rank + + topk = topk_ids.shape[1] + num_tokens, hidden_dim = a.shape + block_size = 128 + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + max_num_tokens = rank_chunk(num_tokens, 0, world_size) + + ata = AllToAll.internode( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), + ) + + dispatch_combine = PplxDispatchCombine( + ata, + max_num_tokens, + world_size, + rank, + dp_size, + a.dtype, + ) + + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + 
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( + a_chunk, + None, + None, + chunk_topk_weight, + chunk_topk_ids, + num_experts, + None, + False, + ) + + b_a = b_a * 1.5 + + out = torch.full( + (max_num_tokens, hidden_dim), + torch.nan, + dtype=a.dtype, + device=device, + ) + + dispatch_combine.combine( + out, + b_a, + chunk_topk_weight, + chunk_topk_ids, + False, + ) + + torch.cuda.synchronize() + + ata.destroy() + + num_tokens = a_chunk.shape[0] + + return out[:num_tokens] + + +def _pplx_dispatch_combine( + pgi: ProcessGroupInfo, + dp_size: int, + a, + score, + topk, + num_experts, +): + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + device = pgi.device + + topk_weight, topk_ids = fused_topk(a, score, topk, False) + k = a.shape[1] + + a_rep = torch.repeat_interleave(a, topk, dim=0).to(device) + + torch_output = (a_rep.view(-1, topk, k) * 1.5 * + topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(a.dtype) + + pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts) + + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +# TODO: this test point does not work for M == 1 +@pytest.mark.parametrize("m", [4, 32, 64, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@requires_pplx +def test_pplx_dispatch_combine( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: tuple[int, int], +): + current_platform.seed_everything(7) + world_size, dp_size = world_dp_size + device = "cuda" + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + score = torch.randn((m, e), device=device, dtype=dtype) + + parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, score, topk, e) + + +def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): + assert torch.cuda.current_device() == pgi.local_rank + + hidden_dim = a.shape[1] + num_experts = w1.shape[0] + block_size = 128 + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + topk = topk_ids.shape[1] + max_num_tokens = rank_chunk(a.shape[0], 0, world_size) + + ata = AllToAll.internode( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), + ) + + dispatch_combine = PplxDispatchCombine( + ata, + max_num_tokens, + world_size, + rank, + dp_size, + ) + + experts = BatchedExperts(a.shape[0]) + + fused_experts = FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + 
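+    # Each rank runs the fused experts on its own contiguous slice of the
+    # tokens and routing data produced by chunk_by_rank above.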
+ out = fused_experts( + a_chunk, + # Chunking weights like this only works for batched format + chunk_by_rank(w1, rank, world_size).to(device), + chunk_by_rank(w2, rank, world_size).to(device), + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts) + + torch.cuda.synchronize() + + ata.destroy() + + return out + + +def _pplx_moe( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, +): + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + + m, k = a.shape + e, _, n = w2.shape + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) + + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +@pytest.mark.parametrize("m", [1, 2, 3, 32, 45, 64, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@requires_pplx +def test_pplx_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: tuple[int, int], +): + current_platform.seed_everything(7) + world_size, dp_size = world_dp_size + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk) diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index 44734e9340a..3b5838a99fa 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -7,6 +7,7 @@ import torch from vllm import _custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.platforms import current_platform @@ -15,6 +16,10 @@ pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): """Matrix multiplication function that supports per-token input @@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale score = torch.randn((M, E), dtype=dtype) - ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, # using fp8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=None, # Not using block quantization - ) + with set_current_vllm_config(vllm_config): + ref_out = torch_w8a8_per_column_moe(a, w1, w2, 
w1_s, w2_s, score, topk) + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, # using fp8 + per_channel_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ) # Check results rel_diff = (torch.mean( diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index c57e39f4250..28e73fdd7c0 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - deep_gemm_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -30,6 +30,10 @@ pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # Test configurations DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] NUM_TOKENS = [7, 83, 2048] @@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. - vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -258,6 +261,7 @@ def per_block_cast_to_fp8( @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes @@ -380,15 +384,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): block_size = [block_m, block_m] dtype = torch.bfloat16 - # only aligned sizes - if (N % block_m != 0 or K % block_m != 0 or topk > E): - pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") - - if N <= 512: - pytest.skip("Skipping N <= 512 until performance issues solved.") + if topk > E: + pytest.skip(f"Skipping test: topk={topk} > E={E}") - vllm_config = VllmConfig() + if not _valid_deep_gemm_shape(M, N, K): + pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") torch.manual_seed(seed) fp8_info = torch.finfo(torch.float8_e4m3fn) diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index 104f23fd7cd..a4e9f83f0ea 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -18,6 +18,10 @@ pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # For test def native_per_token_group_quant_int8(x, @@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. 
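+    # vllm_config is defined once at module scope with the scheduler limits
+    # (max_num_seqs / max_model_len) already set.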
- vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): out = fused_moe( a, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index cb9658ce100..2cedaa06018 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -23,6 +23,7 @@ """ import contextlib import gc +import importlib.util import pickle import weakref from collections import namedtuple @@ -42,7 +43,7 @@ from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, - supports_custom_op) + run_once, supports_custom_op) @dataclass @@ -912,6 +913,40 @@ def init_distributed_environment( "world group already initialized with a different world size") +PPLX_DID_INIT: bool = False + + +@run_once +def pplx_init(rank, world_size): + has_pplx = importlib.util.find_spec("pplx_kernels") is not None + + if has_pplx and world_size > 1: + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, nvshmem_init) + try: + global PPLX_DID_INIT + logger.info("PPLX_INIT rank=%d world=%d", rank, world_size) + uid = nvshmem_get_unique_id( + ) if rank == 0 else nvshmem_alloc_empty_unique_id() + uid_gpu = uid.cuda() + get_world_group().broadcast(uid_gpu, src=0) + logger.debug("PPLX_INIT UID = %s", uid_gpu) + uid = uid_gpu.to(device='cpu') + nvshmem_init(uid, rank, world_size) + PPLX_DID_INIT = True + except Exception as ex: + logger.error("Failed to initialize nvshmem for pplx: %s", ex) + + +@run_once +def pplx_finalize(): + global PPLX_DID_INIT + if PPLX_DID_INIT: + from pplx_kernels.nvshmem import nvshmem_finalize + logger.info("PPLX finalize") + nvshmem_finalize() + + def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, @@ -1006,6 +1041,8 @@ def initialize_model_parallel( "DP rank %s, PP rank %s, TP rank %s", rank, world_size, _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) + pplx_init(rank, world_size) + def ensure_model_parallel_initialized( tensor_model_parallel_size: int, @@ -1081,6 +1118,9 @@ def get_tensor_model_parallel_rank(): def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP + + pplx_finalize() + if _TP: _TP.destroy() _TP = None diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c75d8f088c5..8e03541db87 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -31,7 +31,10 @@ @dataclass class DPMetadata: + max_tokens_across_dp: torch.Tensor + num_tokens_across_dp: torch.Tensor cu_tokens_across_dp_cpu: torch.Tensor + dp_rank_num_tokens: torch.Tensor @dataclass @@ -89,8 +92,16 @@ def set_forward_context(attn_metadata: Any, dtype=torch.int32) from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) + #TODO device? 
(tms) + max_tokens_across_dp = torch.max( + num_tokens_tensor) #.to(device="cuda") cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) - dp_metadata = DPMetadata(cu_tokens_across_dp_cpu) + dp_rank_num_tokens = torch.tensor( + [num_tokens], + dtype=torch.uint32, + device=vllm_config.device_config.device) + dp_metadata = DPMetadata(max_tokens_across_dp, num_tokens_tensor, + cu_tokens_across_dp_cpu, dp_rank_num_tokens) global _forward_context prev_context = _forward_context diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 9829ccdb384..b55a3ede12d 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -38,8 +38,8 @@ def get_config() -> Optional[Dict[str, Any]]: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, get_config_file_name, - grouped_topk) + TritonExperts, fused_experts, fused_moe, fused_topk, + get_config_file_name, grouped_topk) __all__ += [ "fused_moe", @@ -48,4 +48,5 @@ def get_config() -> Optional[Dict[str, Any]]: "get_config_file_name", "grouped_topk", "cutlass_moe_fp8", + "TritonExperts", ] diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 960c7f83485..e52751eddf2 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,10 +1,199 @@ # SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" -from typing import Optional +from typing import Optional, Tuple import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine) +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache + + +class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + out_dtype: torch.dtype, + ): + super().__init__() + self.ab_strides1 = ab_strides1 + self.c_strides1 = c_strides1 + self.ab_strides2 = ab_strides2 + self.c_strides2 = c_strides2 + self.out_dtype = out_dtype + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: + # Note that K, N are transposed + N, K = K, N + workspace1 = M * topk * max(2 * N, K) + workspace2 = M * topk * N + return (workspace1, workspace2, self.out_dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + a1q = hidden_states + + assert w1_scale is not None + assert w2_scale is not None + assert w1.dtype == torch.float8_e4m3fn + assert w2.dtype == torch.float8_e4m3fn + assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1" + assert w1.shape[2] == w2.shape[1] * 2, "Hidden 
size mismatch w2" + assert w1.shape[0] == w2.shape[0], "Expert number mismatch" + assert a1q_scale is None or a1q_scale.dim( + ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ + 0], "Input scale shape mismatch" + assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ + 1] == w1.shape[2], "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ + 1] == w2.shape[2], "W2 scale shape mismatch" + assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch" + assert w1.shape[0] == w1_scale.shape[ + 0], "w1 scales expert number mismatch" + assert w1.shape[0] == w2_scale.shape[ + 0], "w2 scales expert number mismatch" + assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + assert self.ab_strides1.shape[0] == w1.shape[ + 0], "AB Strides 1 expert number mismatch" + assert self.c_strides1.shape[0] == w1.shape[ + 0], "C Strides 1 expert number mismatch" + assert self.ab_strides2.shape[0] == w2.shape[ + 0], "AB Strides 2 expert number mismatch" + assert self.c_strides2.shape[0] == w2.shape[ + 0], "C Strides 2 expert number mismatch" + assert self.out_dtype in [torch.half, + torch.bfloat16], "Invalid output dtype" + + M = a1q.shape[0] + _, N, K = w2.shape # because w1 + w2 are transposed + device = a1q.device + + assert w1.shape[1] == K + assert global_num_experts != -1 + assert a1q_scale is not None + + if expert_map is not None: + "Translate info from expert_map to topk_ids" + local_topk_ids = torch.where(expert_map[topk_ids] != -1, + expert_map[topk_ids], -1) + else: + local_topk_ids = topk_ids + + topk = local_topk_ids.shape[1] + + per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + expert_offsets = torch.empty((global_num_experts + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((global_num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((global_num_experts, 3), + dtype=torch.int32, + device=device) + + # With expert_map each Rank processes only a subset of experts. As + # a result not all of a_map and c2 tensors are filled. We fill it + # zeros for correctness. + if expert_map is not None: + a_map = torch.zeros((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + else: + a_map = torch.empty((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + + c_map = torch.empty((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + + ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, a_map, + c_map, global_num_experts, N, K) + + a1q = _fp8_perm(a1q, a_map) + a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale + + c1 = _resize_cache(workspace13, (M * topk, N * 2)) + c2 = _resize_cache(workspace2, (M * topk, N)) + c3 = _resize_cache(workspace13, (M * topk, K)) + + ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, + expert_offsets[:-1], problem_sizes1, + self.ab_strides1, self.ab_strides1, self.c_strides1) + + self.activation(activation, c2, c1) + + a2q, a2q_scale = ops.scaled_fp8_quant( + c2, a2_scale, use_per_token_if_dynamic=per_act_token) + + if expert_map is not None: + c3.fill_(0) + + ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, + expert_offsets[:-1], problem_sizes2, + self.ab_strides2, self.ab_strides2, self.c_strides2) + + c3 = c3[c_map, ...] 
+ + return c3 + + +def modular_cutlass_moe_fp8( + per_act_token: bool, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + out_dtype: torch.dtype = torch.half, +) -> mk.FusedMoEModularKernel: + return mk.FusedMoEModularKernel( + StandardDispatchCombine( + per_channel_quant=per_act_token, + quant_dtype=torch.float8_e4m3fn, + ), + CutlassExperts( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + out_dtype, + ), + ) #TODO make the grouped gemm kernel consistent with scaled gemm kernel @@ -15,7 +204,7 @@ def cutlass_moe_fp8( w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, - topk_ids_: torch.Tensor, + topk_ids: torch.Tensor, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, ab_strides2: torch.Tensor, @@ -57,7 +246,7 @@ def cutlass_moe_fp8( - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize the intermediate result between the gemms. Shape: scalar or [M] - - out_dtype (torch.Tensor): The output tensor type. + - out_dtype (torch.dtype): The output tensor type. - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, every Rank is responsible for a subset of experts. expert_map is a mapping from global expert-id to local expert-id. When expert_map[i] @@ -69,112 +258,28 @@ def cutlass_moe_fp8( Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. """ - - assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" - assert w1_q.dtype == torch.float8_e4m3fn - assert w2_q.dtype == torch.float8_e4m3fn - assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" - assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" - assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[ - 0], "Input scale shape mismatch" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1_q.shape[2], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2_q.shape[2], "W2 scale shape mismatch" - assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" - assert w1_q.shape[0] == w1_scale.shape[ - 0], "w1 scales expert number mismatch" - assert w1_q.shape[0] == w2_scale.shape[ - 0], "w2 scales expert number mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - assert ab_strides1.shape[0] == w1_q.shape[ - 0], "AB Strides 1 expert number mismatch" - assert c_strides1.shape[0] == w1_q.shape[ - 0], "C Strides 1 expert number mismatch" - assert ab_strides2.shape[0] == w2_q.shape[ - 0], "AB Strides 2 expert number mismatch" - assert c_strides2.shape[0] == w2_q.shape[ - 0], "C Strides 2 expert number mismatch" - assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" - - num_experts = w1_q.size(0) - m = a.size(0) - k = w1_q.size(1) - n = w2_q.size(1) - - local_topk_ids = topk_ids_ - if expert_map is not None: - "Translate info from expert_map to topk_ids" - local_topk_ids = torch.where(expert_map[topk_ids_] != -1, - expert_map[topk_ids_], -1) - - topk = local_topk_ids.size(1) - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - if apply_router_weight_on_input: - assert topk == 1, \ - "apply_router_weight_on_input is only implemented for topk=1" - # TODO: 
this only works for topK=1, will need to update for topK>1 - a = a * topk_weights.to(out_dtype) - - a_q, a1_scale = ops.scaled_fp8_quant( - a, a1_scale, use_per_token_if_dynamic=per_act_token) - device = a_q.device - - expert_offsets = torch.empty((num_experts + 1), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - - a_map_initializer = torch.empty - c2_initializer = torch.empty - if expert_map is not None: - # With expert_map each Rank processes only a subset of experts. As - # a result not all of a_map and c2 tensors are filled. We fill it - # zeros for correctness. - a_map_initializer = torch.zeros - c2_initializer = torch.zeros - - a_map = a_map_initializer((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - c_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - - ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, num_experts, n, - k) - - rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) - rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale - - c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) - c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype) - - ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, - expert_offsets[:-1], problem_sizes1, ab_strides1, - ab_strides1, c_strides1) - - intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) - torch.ops._C.silu_and_mul(intermediate, c1) - - intemediate_q, a2_scale = ops.scaled_fp8_quant( - intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) - - ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, - expert_offsets[:-1], problem_sizes2, ab_strides2, - ab_strides2, c_strides2) - # Gather tokens - c2 = c2[c_map].view(m, topk, k) - if not apply_router_weight_on_input: - c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype) - return c2.sum(dim=1) + + fn = modular_cutlass_moe_fp8( + per_act_token, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + out_dtype, + ) + + return fn( + a, + w1_q, + w2_q, + topk_weights, + topk_ids, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 353c8cc9d59..4a0fb374bd4 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,16 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 +import functools import importlib.util from typing import Optional, Tuple import torch -import vllm.envs as envs -from vllm import _custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, - _fp8_quantize, +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine) +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_permute) +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) from vllm.utils import round_up @@ -19,6 
+20,19 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None +@functools.cache +def deep_gemm_block_shape() -> list[int]: + # Lazy import to avoid CUDA initialization problems. + import deep_gemm as dg + block = dg.get_m_alignment_for_contiguous_layout() + return [block, block] + + +def _valid_deep_gemm_shape(M: int, N: int, K: int): + align = deep_gemm_block_shape()[0] + return align <= M and N % align == 0 and K % align == 0 + + def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -29,89 +43,120 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, aligned by `dg.get_m_alignment_for_contiguous_layout()`. """ if not has_deep_gemm: + logger.debug("DeepGemm disabled: deep_gemm not available.") return False - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - - # Expert maps not supported yet. if expert_map is not None: + logger.debug("DeepGemm disabled: expert map NYI.") return False - align = dg.get_m_alignment_for_contiguous_layout() M = hidden_states.shape[0] _, K, N = w2.shape - - # For now, disable DeepGemm for small N until better permute/unpermute - # ops are available. - if N <= 512: + if not _valid_deep_gemm_shape(M, N, K): + logger.debug("DeepGemm disabled: unalinged problem size.") return False - if align > M or N % align != 0 or K % align != 0: + if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + logger.debug("DeepGemm disabled: invalid weight dtype(s).") return False - return (hidden_states.is_contiguous() and w1.is_contiguous() - and w2.is_contiguous()) - + if (not hidden_states.is_contiguous() or not w1.is_contiguous() + or not w2.is_contiguous()): + logger.debug( + "DeepGemm disabled: weights or activations not contiguous.") + return False -def _moe_permute( - curr_hidden_states: torch.Tensor, - a1q_scale: Optional[torch.Tensor], - curr_topk_ids: torch.Tensor, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - block_m: int, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, - Optional[torch.Tensor]]: - """ - Determine the sorted_token_ids, expert_ids for the given problem size. - Permute the hidden states and scales according to `sorted_token_ids`. 
- """ - top_k_num = curr_topk_ids.shape[1] + return True + + +class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__(self): + super().__init__() + self.block_shape = deep_gemm_block_shape() + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: + block_m = self.block_shape[0] + M_sum = (M * topk) + num_experts * (block_m - 1) + M_sum = round_up(M_sum, block_m) + workspace1 = M_sum * max(N * 2, K) + workspace2 = M_sum * N + return (workspace1, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + import deep_gemm as dg + + a1q = hidden_states + _, N, K = w1.shape + + assert global_num_experts != -1 + assert w2.shape[1] == K + + a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute( + a1q, + a1q_scale, + topk_ids, + global_num_experts, + expert_map, + self.block_shape[0], + ) + + # Note: M_sum is different than the pre-permuted shape of a1q. + M_sum = a1q.shape[0] + workspace1 = _resize_cache(workspace13, (M_sum, N)) + workspace2 = _resize_cache(workspace2, (M_sum, N // 2)) + workspace3 = _resize_cache(workspace13, (M_sum, K)) - tokens_in_chunk, _ = curr_hidden_states.shape + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids) - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, - block_m, - global_num_experts, - expert_map, - pad_sorted_ids=True)) + self.activation(activation, workspace2, workspace1.view(-1, N)) - inv_perm: Optional[torch.Tensor] = None + a2q_scale: Optional[torch.Tensor] = None - num_tokens = top_k_num * tokens_in_chunk - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) - inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] + a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False, + self.block_shape) - # Permute according to sorted token ids. - curr_hidden_states = _fp8_perm(curr_hidden_states, - sorted_token_ids // top_k_num) + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids) - if a1q_scale is not None: - a1q_scale = a1q_scale[sorted_token_ids // top_k_num] + workspace3 = workspace3[inv_perm, ...] - return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) + return workspace3 -def _moe_unpermute_and_reduce( - out: torch.Tensor, - curr_hidden: torch.Tensor, - inv_perm: Optional[torch.Tensor], - topk_weight: torch.Tensor, -) -> None: - """ - Unpermute the final result and apply topk_weights, then perform the final - reduction on the hidden states. - """ - M, topk = topk_weight.shape - K = curr_hidden.shape[1] - curr_hidden = curr_hidden[inv_perm, ...] 
- curr_hidden = curr_hidden.view(-1, topk, K) - curr_hidden.mul_(topk_weight.view(M, -1, 1)) - ops.moe_sum(curr_hidden, out) +def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel: + return mk.FusedMoEModularKernel( + StandardDispatchCombine(quant_dtype=torch.float8_e4m3fn, + block_shape=deep_gemm_block_shape()), + DeepGemmExperts(), + ) def deep_gemm_moe_fp8( @@ -128,6 +173,7 @@ def deep_gemm_moe_fp8( expert_map: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + apply_router_weight_on_input=False, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -166,129 +212,20 @@ def deep_gemm_moe_fp8( Returns: - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. """ - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - - assert expert_map is None, "Expert maps not supported yet" - - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.stride(-1) == 1, "Stride of last dimension must be 1" - assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - assert w1.dtype == torch.float8_e4m3fn - assert w2.dtype == torch.float8_e4m3fn - assert w1.shape[0] == w2.shape[0], "Expert number mismatch" - assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" - assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[ - 0] == hidden_states.shape[0], "Input scale shape mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - - num_tokens, _ = hidden_states.shape - E, N, _ = w1.shape - K = w2.shape[1] - if global_num_experts == -1: - global_num_experts = E - - # We execute the fused_moe kernel in chunks to circumvent this issue: - # https://github.com/vllm-project/vllm/issues/5938 - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - - assert _valid_deep_gemm(hidden_states, w1, w2, expert_map) - - if inplace: - out_hidden_states = hidden_states - else: - out_hidden_states = torch.empty_like(hidden_states) - - block_m = dg.get_m_alignment_for_contiguous_layout() - block_shape = [block_m, block_m] - - assert w1_scale is not None - assert w2_scale is not None - - # We attempt to transpose and align offline in Fp8MoEMethod, in which - # case these calls will be nops. Otherwise, they'll be performed every - # time the layer is executed. 
- w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() - w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - - M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) - M_sum = round_up(M_sum, block_m) - - num_chunks = (num_tokens // CHUNK_SIZE) + 1 - - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 - workspace13 = torch.empty(M_sum * max(N, K), - device=hidden_states.device, - dtype=hidden_states.dtype) - - workspace1 = workspace13[:M_sum * N].view(M_sum, N) - workspace2 = torch.empty((M_sum, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - workspace3 = workspace13[:M_sum * K].view(M_sum, K) - - for chunk in range(num_chunks): - begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, - num_tokens)) - curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk, _ = curr_hidden_states.shape - - if tokens_in_chunk == 0: - break - - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - - a1q_scale: Optional[torch.Tensor] = None - - qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states, - a1_scale, block_shape) - - (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale, - curr_topk_ids, global_num_experts, - expert_map, block_m) - - # Adjust the intermediate cache size and config for the last chunk. - # Note that in most cases we only have one chunk so the cache size - # and config are already set correctly and do not need to be adjusted. - if tokens_in_chunk < CHUNK_SIZE and chunk > 0: - curr_M = sorted_token_ids.numel() - workspace1 = _resize_cache(workspace1, (curr_M, N)) - workspace2 = _resize_cache(workspace2, (curr_M, N // 2)) - workspace3 = _resize_cache(workspace3, (curr_M, K)) - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1, - expert_ids) - - if activation == "silu": - torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N)) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N)) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") - - a2q_scale: Optional[torch.Tensor] = None - - qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale, - block_shape) - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids) - - _moe_unpermute_and_reduce( - out_hidden_states[begin_chunk_idx:end_chunk_idx], - workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights) - - return out_hidden_states + fn = modular_deep_gemm_fused_moe_fp8() + return fn( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + activation, + global_num_experts, + expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py new file mode 100644 index 00000000000..9b647a70d5e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from 
vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_unpermute_and_reduce) +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) + + +class StandardDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + + def __init__( + self, + quant_dtype: Optional[torch.dtype] = None, + per_channel_quant: bool = False, + block_shape: Optional[list[int]] = None, + ): + super().__init__() + self.per_channel_quant = per_channel_quant + self.block_shape = block_shape + self.quant_dtype = quant_dtype + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + if apply_router_weight_on_input: + topk = topk_ids.shape[1] + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, + self.quant_dtype, + self.per_channel_quant, + self.block_shape) + + return a1q, a1q_scale, None + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + _moe_unpermute_and_reduce(output, fused_expert_output, None, + topk_weights, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py new file mode 100644 index 00000000000..be700f7b2e9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -0,0 +1,791 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Fused batched MoE kernel.""" +from typing import List, Optional, Tuple + +import torch +import triton +import triton.language as tl + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_config_dtype_str, try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache + + +@triton.jit +def batched_silu_and_mul_kernel( + output, # [E, MAX_NUM_TOKENS, D] + input, # [E, MAX_NUM_TOKENS, D * 2] + expert_num_tokens, # [E] + stride_oe, + stride_om, + stride_ie, + stride_im, + compute_type: tl.constexpr, + D, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr): + + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # early exit + return + + pid_m = tl.program_id(axis=1) + cta_m_start = pid_m * BLOCK_M + if cta_m_start >= e_num_tokens: + # early exit + return + + cta_input_ptr = input + expert_id * stride_ie + cta_m_start * stride_im + cta_output_ptr = output + expert_id * stride_oe + cta_m_start * stride_om + + cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) + offs_m = tl.arange(0, BLOCK_M)[:, None] + mask_m = offs_m < cta_m_size + + cta_input_ptrs = cta_input_ptr + offs_m * stride_im + cta_output_ptrs = cta_output_ptr + offs_m * stride_om + + # offset by D + offs_D = tl.arange(0, BLOCK_D) + cta_input_ptrs = cta_input_ptrs + offs_D + cta_output_ptrs = cta_output_ptrs + offs_D + + for d in range(0, tl.cdiv(D, BLOCK_D)): + mask_D = offs_D < (D - (d * BLOCK_D)) + mask_tile = mask_m & mask_D + + x_tile = 
tl.load(cta_input_ptrs, mask=mask_tile, + other=0.0).to(dtype=tl.float32) + y_tile = tl.load(cta_input_ptrs + D, mask=mask_tile, other=0.0) + + # silu and mul + out_tile = (x_tile * (1.0 / + (1.0 + tl.exp(-x_tile)))).to(dtype=compute_type) + out_tile = out_tile * y_tile + tl.store(cta_output_ptrs, out_tile, mask=mask_tile) + + cta_input_ptrs = cta_input_ptrs + BLOCK_D + cta_output_ptrs = cta_output_ptrs + BLOCK_D + + +@triton.jit +def moe_mmk( + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr): + + offs_k = tl.arange(0, BLOCK_K) + + if use_w8a16: + b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + offs_bsn = offs_n // group_n + b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse + + offs_bsn * stride_bsn) + # tensor-wise + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + expert_id) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load(a_ptrs, + mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K), + other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + # We accumulate along the K dimension. + if use_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, + mask=mask_m, + other=0.0) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, + None] * b_scale[None, :] + else: + if use_w8a8: + # acc used to enable fp8_fast_accum + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + if use_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + + return accumulator + + +@triton.jit +def expert_triton_kernel( + a_ptr, #[max_tokens, K] + b_ptr, #[K, N] + c_ptr, #[max_tokens, N] + expert_id, + compute_type: tl.constexpr, + # Dimensions + M, + N, + K, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) % N + offs_k = tl.arange(0, BLOCK_K) + mask_m = offs_m < M + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn + + accumulator = moe_mmk( + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n, + group_k, + # Meta-parameters + BLOCK_M, + BLOCK_N, + BLOCK_K, + compute_type, + use_fp8_w8a8, + use_int8_w8a16) + + # store in C + offs_cn = tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = mask_m[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def batched_triton_kernel( + a_ptr, # [E, max_num_tokens, K] + b_ptr, # [E, K, N] + c_ptr, # [E, max_num_tokens, N] + expert_num_tokens, # [E] + compute_type: tl.constexpr, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). 
+ stride_ae, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_ce, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n: tl.constexpr, + group_k: tl.constexpr, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # Early exit + return + + pid_mn = tl.program_id(axis=1) + #num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid_mn // num_pid_n + pid_n = pid_mn % num_pid_n + + cta_m_start = pid_m * BLOCK_M + cta_n_start = pid_n * BLOCK_N + if cta_m_start >= e_num_tokens: + # Early exit + return + + cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) + cta_n_size = min(BLOCK_N, N - cta_n_start) + + a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am + b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn + c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + + cta_n_start * stride_cn) + + expert_triton_kernel( + a_ptr, + b_ptr, + c_ptr, + expert_id, + compute_type, + cta_m_size, # M + cta_n_size, # N + K, # K + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # Strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M, + BLOCK_N, + BLOCK_K) + + +def invoke_moe_batched_triton_kernel( + A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, K, N] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: torch.Tensor, + B_scale: torch.Tensor, + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + block_shape: Optional[list[int]] = None): + + assert not use_int4_w4a16 + max_num_tokens = A.size(1) + K = A.size(2) + N = C.size(2) + + BLOCK_M = config['BLOCK_SIZE_M'] + BLOCK_N = config['BLOCK_SIZE_N'] + BLOCK_K = config['BLOCK_SIZE_K'] + assert max_num_tokens % BLOCK_M == 0 + + grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * + triton.cdiv(B.shape[1], BLOCK_N)) + + batched_triton_kernel[grid]( + A, + B, + C, + expert_num_tokens, + compute_type, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + A_scale, + B_scale, + B_zp, + # Strides + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(0), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + # Blockwise quantization data + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K) + + +def invoke_batched_silu_and_mul( + output: 
torch.Tensor, #[E, MAX_TOKENS, D] + input: torch.Tensor, #[E, MAX_TOKENS, D * 2] + expert_num_tokens: torch.Tensor): + + num_experts = output.size(0) + max_num_tokens = output.size(1) + D = output.size(2) + + BLOCK_D = 1024 + BLOCK_M = 1 + + compute_tl_dtype = { + torch.float16: tl.float16, + torch.float32: tl.float32, + torch.bfloat16: tl.bfloat16 + }[output.dtype] + + #print(f"compute type {compute_tl_dtype}") + + grid = (num_experts, triton.cdiv(max_num_tokens, BLOCK_M)) + batched_silu_and_mul_kernel[grid](output, input, expert_num_tokens, + output.stride(0), output.stride(1), + input.stride(0), input.stride(1), + compute_tl_dtype, D, BLOCK_M, BLOCK_D) + + +class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + + def __init__(self, world_size: int, rank: int): + super().__init__() + self.world_size = world_size + self.rank = rank + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a1.shape[0] + + if apply_router_weight_on_input: + topk = topk_ids.shape[1] + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + num_tokens = a1.shape[0] + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), + minlength=num_experts) + max_num_tokens = tokens_per_expert.max() + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=a1.device) + + b_a1 = torch.zeros((num_experts, max_num_tokens, a1.shape[1]), + dtype=a1.dtype, + device=a1.device) + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = expert_counts[expert_id] + b_a1[expert_id, idx:idx + 1, :] = a1[token, :] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + return b_a1, a1_scale, tokens_per_expert + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + num_tokens = topk_ids.shape[0] + num_experts = fused_expert_output.shape[0] + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=fused_expert_output.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(topk_ids.shape[1]): + expert_id = expert_ids[i] + if expert_id < num_experts: + idx = expert_counts[expert_id] + if apply_router_weight_on_input: + output[token, :] = output[ + token, :] + fused_expert_output[expert_id, + idx:idx + 1, :] + else: + output[ + token, :] = output[token, :] + fused_expert_output[ + expert_id, + idx:idx + 1, :] * topk_weights[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + +class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + max_num_tokens: Optional[int] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, + ): + super().__init__() + assert block_shape is None + assert block_m is None + assert not use_fp8_w8a8, "NYI" + assert not use_int8_w8a8, "NYI" + assert not 
use_int8_w8a16, "NYI" + assert not use_int4_w4a16, "NYI" + self.max_num_tokens = max_num_tokens + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: + max_num_tokens = a.shape[ + 1] if self.max_num_tokens is None else self.max_num_tokens + # TODO: *2 is a hack + workspace13 = num_experts * max_num_tokens * K * topk * 2 + workspace2 = max_num_tokens * N + return (workspace13, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + assert hidden_states.dim() == 3 + assert expert_num_tokens is not None + + if self.max_num_tokens is None: + max_num_tokens = hidden_states.shape[1] + else: + max_num_tokens = self.max_num_tokens + + num_experts = global_num_experts + out = _resize_cache(workspace13, + (num_experts, max_num_tokens, w2.shape[1])) + num_local_experts = expert_num_tokens.numel() + + for expert in range(num_local_experts): + num = expert_num_tokens[expert] + assert num <= max_num_tokens, f"{num}, {max_num_tokens}" + if num > 0: + tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) + self.activation( + activation, tmp, + hidden_states[expert, :num, :] @ w1[expert].transpose( + 0, 1)) + out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) + + return out + + +class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + max_num_tokens: Optional[int] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[List[int]] = None, + ): + super().__init__() + self.use_fp8_w8a8 = use_fp8_w8a8 + self.use_int8_w8a8 = use_int8_w8a8 + self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a16 = use_int8_w8a16 + self.block_shape = block_shape + self.max_num_tokens = max_num_tokens + assert not use_int8_w8a8, "NYI" + assert not use_int4_w4a16, "NYI" + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: + max_num_tokens = a.shape[ + 1] if self.max_num_tokens is None else self.max_num_tokens + workspace13 = num_experts * max_num_tokens * max(K, N) + workspace2 = num_experts * max_num_tokens * (N // 2) + return (workspace13, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + + num_tokens = topk_ids.size(0) + + # Check constraints. 
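+        # Note: hidden_states is expected in the batched per-expert layout
+        # produced by BatchedDispatchCombine.dispatch (or the pplx dispatcher):
+        # shape [E, max_num_tokens, K], where only the first
+        # expert_num_tokens[e] rows of slice e hold valid tokens.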
+ if self.use_int4_w4a16: + assert hidden_states.shape[-1] // 2 == w1.shape[ + 2], "Hidden size mismatch" + else: + assert hidden_states.shape[-1] == w1.shape[2], \ + (f"Hidden size mismatch {hidden_states.shape[-1]} " + f"!= {w1.shape[2]}") + + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + ] + + E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + assert w1.shape[0] == E + assert w2.shape[0] == E + + config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + dtype=hidden_states.dtype) + + config = try_get_optimal_moe_config( + w1.shape, + w2.shape, + top_k_num, + config_dtype, + num_tokens, + block_shape=self.block_shape, + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + elif hidden_states.dtype == torch.float8_e4m3fn: + compute_type = tl.bfloat16 + else: + raise ValueError( + f"Unsupported compute_type: {hidden_states.dtype}") + + #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N)) + intermediate_cache2 = _resize_cache(workspace2, + (E, num_tokens, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K)) + + # MM1 + invoke_moe_batched_triton_kernel(A=hidden_states, + B=w1, + C=intermediate_cache1, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a1q_scale, + B_scale=w1_scale, + B_zp=w1_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + block_shape=self.block_shape) + + # Fix activations + assert activation == "silu" + invoke_batched_silu_and_mul(output=intermediate_cache2, + input=intermediate_cache1, + expert_num_tokens=expert_num_tokens) + + #qintermediate_cache2 = intermediate_cache2 + a2q_scale = a2_scale + # TODO (varun) : support w8a8 + assert not self.use_fp8_w8a8 + #if self.use_fp8_w8a8: + # qintermediate_cache2, a2q_scale = _fp8_quantize( + # intermediate_cache2, a2_scale, self.block_shape) + + invoke_moe_batched_triton_kernel(A=intermediate_cache2, + B=w2, + C=intermediate_cache3, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a2q_scale, + B_scale=w2_scale, + B_zp=w2_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + block_shape=self.block_shape) + + return intermediate_cache3 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a209715ede7..bb7658519bc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -10,16 +10,17 @@ import triton.language as tl import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger from 
vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8) +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) -from vllm.model_executor.layers.quantization.utils.int8_utils import ( - per_token_group_quant_int8, per_token_quant_int8) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, moe_kernel_quantize_input) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -485,6 +486,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + if use_fp8_w8a8 or use_int8_w8a8: + assert B_scale is not None + assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) + == B_scale.shape[-2]) + assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) + == B_scale.shape[-1]) + + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + M = A.shape[0] num_tokens = M * top_k @@ -962,6 +977,20 @@ def get_config_dtype_str( return None +# TODO: use scalar_type? +def get_config_qtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, +) -> Optional[torch.dtype]: + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8: + return torch.int8 + return None + + def inplace_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1131,7 +1160,10 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, allow_deep_gemm: bool = False) -> torch.Tensor: - if (allow_deep_gemm and use_fp8_w8a8 + # For now, disable DeepGemm for small N (<= 512) until better + # permute/unpermute ops are available. + N = w1.shape[1] + if (allow_deep_gemm and use_fp8_w8a8 and N > 512 and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( @@ -1148,6 +1180,7 @@ def fused_experts(hidden_states: torch.Tensor, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, ) else: return dispatch_fused_experts_func(inplace)( @@ -1174,87 +1207,36 @@ def fused_experts(hidden_states: torch.Tensor, block_shape=block_shape) -def moe_kernel_prepare_input( - A: torch.Tensor, - B: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, - block_shape: Optional[List[int]] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if use_fp8_w8a8: - assert B_scale is not None - if block_shape is None: - # If weights are per-channel (per_channel_quant=True), then - # activations apply per-token quantization. 
Otherwise, assume - # activation tensor-wise fp8 quantization, dynamic or static - A, A_scale = ops.scaled_fp8_quant( - A, A_scale, use_per_token_if_dynamic=per_channel_quant) - else: - # activation block-wise fp8 quantization - assert len(block_shape) == 2 - _, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_fp8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] - # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a8: - assert B_scale is not None - if block_shape is None: - # activation channel-wise int8 quantization - assert (per_channel_quant - ), "int8 quantization only supports block or channel-wise" - A, A_scale = per_token_quant_int8(A) - else: - # activation block-wise int8 quantization - assert len(block_shape) == 2 - _, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_int8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] - # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a16 or use_int4_w4a16: - assert B_scale is not None - assert block_shape is None or block_shape[0] == 0 - else: - assert A_scale is None - assert B_scale is None - - return A, A_scale - - -def fused_experts_impl(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None): +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> torch.Tensor: # Check constraints. 
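To summarize the refactor above: the per-scheme quantization branches that used to live in moe_kernel_prepare_input are now folded into a single activation-quantization dtype plus the shared moe_kernel_quantize_input helper. Roughly, an old call site maps onto the new one as follows (a sketch of the mapping only, mirroring the call sites later in this file):

qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
                         use_int8_w8a8=use_int8_w8a8,
                         use_int8_w8a16=use_int8_w8a16,
                         use_int4_w4a16=use_int4_w4a16)
A_q, A_scale = moe_kernel_quantize_input(A=A,
                                         A_scale=A_scale,
                                         qtype=qtype,
                                         per_channel_quant=per_channel_quant,
                                         block_shape=block_shape)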
if use_int4_w4a16: assert hidden_states.shape[1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[2], \ + f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" @@ -1264,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, torch.float32, torch.float16, torch.bfloat16 ] - num_tokens, _ = hidden_states.shape + num_tokens = hidden_states.shape[0] E, N, _ = w1.shape K = w2.shape[1] if global_num_experts == -1: @@ -1279,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype) + qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16) + get_config_func = functools.partial( try_get_optimal_moe_config, w1.shape, @@ -1341,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input( + qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, - B=w1, A_scale=a1_scale, - B_scale=w1_scale, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, + qtype=qtype, per_channel_quant=per_channel_quant, block_shape=block_shape) @@ -1360,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, invoke_fused_moe_kernel(qcurr_hidden_states, w1, intermediate_cache1, - qa1_scale, + a1q_scale, w1_scale, w1_zp, curr_topk_weights, @@ -1387,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") - qintermediate_cache2, qa2_scale = moe_kernel_prepare_input( + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, - B=w2, A_scale=a2_scale, - B_scale=w2_scale, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, + qtype=qtype, per_channel_quant=per_channel_quant, block_shape=block_shape) invoke_fused_moe_kernel(qintermediate_cache2, w2, intermediate_cache3, - qa2_scale, + a2q_scale, w2_scale, w2_zp, curr_topk_weights, @@ -1537,3 +1514,234 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape) + + +class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + per_channel_quant: bool, + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, + ): + super().__init__() + self.use_fp8_w8a8 = use_fp8_w8a8 + self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a8 = use_int8_w8a8 + self.use_int8_w8a16 = use_int8_w8a16 + self.block_shape = block_shape + self.block_m = block_m + self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16) + self.per_channel_quant = per_channel_quant + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: + factor = num_experts if a.dim() == 3 else 1 + 
workspace1 = M * topk * max(N * 2, K) * factor + workspace2 = M * topk * N * factor + return (workspace1, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + # Check constraints. + if self.use_int4_w4a16: + assert hidden_states.shape[-1] // 2 == w1.shape[ + 2], "Hidden size mismatch" + else: + assert hidden_states.shape[-1] == w1.shape[2], \ + (f"Hidden size mismatch {hidden_states.shape[-1]} " + f"!= {w1.shape[2]}") + + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + ] + + E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + if global_num_experts == -1: + global_num_experts = E + + config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + dtype=hidden_states.dtype) + + config = try_get_optimal_moe_config( + w1.shape, + w2.shape, + top_k_num, + config_dtype, + num_tokens, + block_shape=self.block_shape, + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + elif hidden_states.dtype == torch.float8_e4m3fn: + compute_type = tl.bfloat16 + else: + raise ValueError( + f"Unsupported compute_type: {hidden_states.dtype}") + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, + (num_tokens, top_k_num, N)) + intermediate_cache2 = _resize_cache(workspace2, + (num_tokens * top_k_num, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, + (num_tokens, top_k_num, K)) + + if hidden_states.dim() == 2: #block_m is None: + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) + else: + max_num_tokens = hidden_states.shape[1] + sorted_token_ids = torch.arange(0, + hidden_states.shape[0] * + max_num_tokens, + device=hidden_states.device, + dtype=torch.int) + sorted_token_ids = sorted_token_ids.flatten() + expert_ids = torch.arange(0, + global_num_experts, + device=hidden_states.device, + dtype=torch.int) + expert_ids = torch.repeat_interleave(expert_ids, + max_num_tokens, + dim=0) + print(f"EXPERT_IDS {expert_ids}") + #num_tokens_post_padded = torch.tensor([num_tokens], + # device=hidden_states.device, + # dtype=torch.int32) + num_tokens_post_padded = torch.zeros(1, + device=hidden_states.device, + dtype=torch.int32) + num_tokens_post_padded.fill_(max_num_tokens) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + #print(f"P = {sorted_token_ids}, {hidden_states.shape}") + + invoke_fused_moe_kernel(hidden_states, + w1, + 
intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, + block_shape=self.block_shape) + + self.activation(activation, intermediate_cache2, + intermediate_cache1.view(-1, N)) + + a2q_scale: Optional[torch.Tensor] = None + + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( + intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant, + self.block_shape) + + invoke_fused_moe_kernel(qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, + block_shape=self.block_shape) + + return intermediate_cache3 + + +def modular_triton_fused_moe( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + per_channel_quant: bool, + block_shape: Optional[List[int]] = None, +) -> mk.FusedMoEModularKernel: + qtype = get_config_qtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + ) + return mk.FusedMoEModularKernel( + StandardDispatchCombine( + quant_dtype=qtype, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ), + TritonExperts( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ), + ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3cdf3c97a7d..9e48d84be2f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,6 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 +import importlib +import threading +import weakref from abc import abstractmethod +from dataclasses import dataclass from enum import Enum from typing import Callable, List, Optional, Tuple @@ -9,7 +13,7 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs -from vllm.config import get_current_vllm_config +from vllm.config import get_current_vllm_config, ParallelConfig from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -23,8 +27,17 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op +has_pplx = importlib.util.find_spec("pplx_kernels") is not None + if current_platform.is_cuda_alike(): - from .fused_moe import fused_experts + from .dispatch_combine import StandardDispatchCombine + from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts + from .fused_moe import TritonExperts, fused_experts + from .modular_kernel import (FusedMoEModularKernel, + FusedMoEPermuteExpertsUnpermute, + FusedMoEQuantizeDispatchCombine) + if has_pplx: + from .pplx_dispatch_combine import PplxDispatchCombine else: fused_experts = None # type: ignore if current_platform.is_tpu(): @@ -34,6 +47,162 @@ fused_moe_pallas = None # 
type: ignore logger = init_logger(__name__) +MOE_DP_CHUNK_SIZE = 256 + + +@dataclass +class FusedMoEParallelConfig: + tp_size: int + dp_size: int + ep_size: int + tp_rank: int + dp_rank: int + ep_rank: int + + use_ep: bool # whether to use EP or not + + @property + def use_pplx_kernels(self): + return self.use_ep and has_pplx + + @staticmethod + def make(tp_size_: int, dp_size_: int, + vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + """ + Determine MoE parallel configuration. Based on the input tp_size_, dp_size_, + ep_size_ and vllm's parallel config, determine what level's of parallelism + to use in the fused moe layer. + + Args: + tp_size_ (int): tp_size passed into the FusedMoE constructor. + dp_size_ (int): dp_size passed into the FusedMoE constructor. + ep_size_ (int): ep_size passed into the FusedMoE constructor. + vllm_parallel_config (ParallelConfig): vllm's parallel config object. + + Examples: + When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, + we simply return the sizes unaltered and the ranks set to 0. + + Expert Parallelism is considered only when either dp_size_ or tp_size_ is non trivial. + + When TP = 2, DP = 1 and EP = False, the configuration on different devices, + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // legend : {size, rank} + - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} + - Comment : Tensors are sharded across 2 devices. + + When TP = 1, DP = 2 and EP = False, the configuration on different devices, + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} + - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded across 2 decvices. + + When TP = 2, DP = 2 and EP = False, the configuration on different devices, + - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} + - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} + - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} + - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded across 4 devices. + + When, TP = 2, DP = 1 and EP = True, the configuration on different devices, + - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} + - Comment: The experts are split between the 2 devices. + + When, TP = 1, DP = 2 and EP = True, the configuration on different devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} + - Comment: There are 2 engine instances and the experts are split between the 2 devices. + + When TP = 2, DP = 2 and EP = True, the configuration on different devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} + - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} + - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} + - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} + - Comment: There are 2 engine instances and the experts are split between the 4 devices. + """ + + def flatten_tp_across_dp(dp_rank: int): + tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size_ * tp_size_ devices. Update tp_size + # and tp_rank so we shard across all devices. 
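+            # Worked example for TP = 2, DP = 2: every device computes
+            #   tp_size = dp_size_ * tp_size_ = 4
+            #   tp_rank = dp_rank * 2 + local_tp_rank  ->  0, 1, 2, 3
+            # which matches the TP = {4, r} entries in the docstring above.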
+ tp_size = dp_size_ * tp_size_ + tp_rank = dp_rank * tp_size_ + tp_rank + return tp_size, tp_rank + + use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel + + dp_size = dp_size_ + dp_rank = get_dp_group().rank_in_group + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + + if not use_ep: + return FusedMoEParallelConfig(tp_size = tp_size, + tp_rank = tp_rank, + dp_size = dp_size, + dp_rank = dp_rank, + ep_size = 1, + ep_rank = 0, + use_ep = False) + # DP + EP / TP + EP / DP + TP + EP + assert use_ep + # In EP, each device owns a set of experts fully. There is no tensor parallel. + # Update tp_size, tp_rank, ep_size and ep_rank to reflect that. + ep_size = tp_size + ep_rank = tp_rank + return FusedMoEParallelConfig(tp_size = 1, + tp_rank = 0, + dp_size = dp_size, + dp_rank = dp_rank, + ep_size = ep_size, + ep_rank = ep_rank, + use_ep = True) + +# Adapted from pplx-kernels tests/all_to_all_utils.py +@dataclass +class MoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + + num_local_experts: int + moe_parallel_config: FusedMoEParallelConfig + + in_dtype: torch.dtype # The activation type. + + # TODO: add more quantization params, blocked, per-token, etc. + block_size: int = 128 + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" @@ -50,6 +219,10 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError + def set_dispatch_combine( + self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + return False + @abstractmethod def apply( self, @@ -72,10 +245,46 @@ def apply( raise NotImplementedError +class AllToAllCache: + + def __init__(self): + self._cache = weakref.WeakValueDictionary() + self._lock = threading.RLock() # Reentrant lock for thread safety + + def get_or_create(self, **kwargs): + assert has_pplx + import pplx_kernels as pplx + + # Create a hashable key from the kwargs + key = tuple(sorted((k, v) for k, v in kwargs.items())) + + with self._lock: + instance = self._cache.get(key) + if instance is None: + # TODO: should be intranode + instance = pplx.AllToAll.internode(**kwargs) + self._cache[key] = instance + return instance + + +# Global singleton +_all_to_all_cache = AllToAllCache() + + +# Factory function as a cleaner interface +def get_all_to_all(**kwargs): + return _all_to_all_cache.get_or_create(**kwargs) + + @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" + def __init__(self, moe: MoEConfig): + super().__init__() + self.fused_experts = fused_experts + self.moe = moe + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -173,6 +382,40 @@ def apply( activation=activation, 
apply_router_weight_on_input=apply_router_weight_on_input) + def set_dispatch_combine( + self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + assert self.fused_experts == fused_experts + + experts: Optional[FusedMoEPermuteExpertsUnpermute] = None + + if isinstance(dispatch_combine, + (BatchedDispatchCombine, PplxDispatchCombine)): + logger.debug("BatchedTritonExperts %s", self.moe) + experts = BatchedTritonExperts( + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + ) + else: + logger.debug("TritonExperts %s", self.moe) + experts = TritonExperts( + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + per_channel_quant=False, + ) + + self.fused_experts = FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + return True + def forward_cuda( self, layer: torch.nn.Module, @@ -191,6 +434,7 @@ def forward_cuda( apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -203,7 +447,7 @@ def forward_cuda( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return fused_experts( + return self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -213,7 +457,8 @@ def forward_cuda( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, - expert_map=expert_map) + expert_map=expert_map, + ) def forward_cpu( self, @@ -369,6 +614,57 @@ def determine_expert_map( return (local_num_experts, expert_map) +def construct_dispatch_combine(moe: MoEConfig, + quant_config: Optional[QuantizationConfig]) -> FusedMoEQuantizeDispatchCombine: + + dispatch_combine: FusedMoEQuantizeDispatchCombine = None + if moe.use_pplx_kernels: + logger.info("using pplx dispatch") + max_num_tokens = MOE_DP_CHUNK_SIZE + world_size = moe.ep_size + rank = moe.ep_rank + dp_size= moe.ep_size // moe.dp_size # dp_size actually means TP. + + all_to_all = get_all_to_all( + max_num_tokens=max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=rank, + world_size=world_size, + dp_size= dp_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( + (moe.hidden_dim + moe.block_size - 1) // moe.block_size * + torch.float32.itemsize))) + + dispatch_combine = PplxDispatchCombine( + all_to_all, + max_num_tokens, + world_size, + rank, # just for debugging + dp_size, + moe.in_dtype, + ) + elif True: + logger.info("using standard dispatch") + dispatch_combine = StandardDispatchCombine( + moe.in_dtype, + quant_config.weight_block_size + if quant_config is not None else None, + ) + else: + logger.info("using batched dispatch") + dispatch_combine = BatchedDispatchCombine( + moe.ep_size, + moe.ep_rank, + ) + return dispatch_combine + + class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. 
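As a standalone illustration of the modular pieces wired up above (construct_dispatch_combine picks a dispatcher, set_dispatch_combine pairs it with an experts implementation inside a FusedMoEModularKernel), the modular_triton_fused_moe factory added in fused_moe.py can also be driven directly. The sketch below uses toy shapes, random weights and a CUDA device, all of which are assumptions; the keyword names mirror the self.fused_experts call in forward_cuda above:

import torch

from vllm.model_executor.layers.fused_moe.fused_moe import (
    modular_triton_fused_moe)

E, M, hidden, inter, top_k = 8, 16, 128, 256, 2
x = torch.randn(M, hidden, device="cuda", dtype=torch.bfloat16)
w13 = torch.randn(E, 2 * inter, hidden, device="cuda",
                  dtype=torch.bfloat16) / 10
w2 = torch.randn(E, hidden, inter, device="cuda", dtype=torch.bfloat16) / 10
scores = torch.softmax(torch.randn(M, E, device="cuda"), dim=-1)
topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)

mk_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
                                 use_int8_w8a8=False,
                                 use_int8_w8a16=False,
                                 use_int4_w4a16=False,
                                 per_channel_quant=False,
                                 block_shape=None)
out = mk_fn(hidden_states=x,
            w1=w13,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids.to(torch.int32),
            inplace=False,
            activation="silu",
            global_num_experts=E,
            expert_map=None)  # out: [M, hidden]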
@@ -419,21 +715,13 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - # Note: here we guard against accessing the TP and DP groups when - # uninitialized (this happens when testing) - self.tp_size = (tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()) - tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank() - self.dp_size = (dp_size - if dp_size is not None else get_dp_group().world_size) - self.dp_rank = (0 - if self.dp_size == 1 else get_dp_group().rank_in_group) - self.global_num_experts = num_experts - - # Use expert parallelism instead of tensor parallelism? vllm_config = get_current_vllm_config() - use_ep = (vllm_config.parallel_config.enable_expert_parallel - and self.tp_size * self.dp_size > 1) + self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( + tp_size_ = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()), + dp_size_ = (dp_size if dp_size is not None else get_dp_group().world_size), + vllm_parallel_config=vllm_config.parallel_config) + + self.global_num_experts = num_experts # For smuggling this layer into the fused moe custom op self.use_direct_call = self.dp_size == 1 @@ -444,28 +732,16 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix - if use_ep: - # Set TP size to 1 to adjust for EP and adjust EP size and rank - # for DP attention. - self.ep_rank = tp_rank + self.tp_size * self.dp_rank - self.tp_rank = 0 - self.ep_size = self.tp_size * self.dp_size - self.tp_size = 1 - + # Determine expert maps + if self.use_ep: self.local_num_experts, self.expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts) else: - # Adjust TP size for DP attention - self.tp_rank = tp_rank + self.tp_size * self.dp_rank - self.ep_rank = 0 - self.tp_size = self.tp_size * self.dp_size - self.ep_size = 1 - self.local_num_experts = self.global_num_experts - self.expert_map = None + self.local_num_experts, self.expert_map = (self.global_num_experts, None) + self.top_k = top_k - self.global_num_experts = num_experts assert intermediate_size % self.tp_size == 0 self.hidden_size = hidden_size @@ -489,14 +765,35 @@ def __init__( from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) + moe = MoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=params_dtype, # TODO: is this right? + ) + # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. 
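+        # The MoEConfig built above is also what construct_dispatch_combine()
+        # keys off of (use_pplx_kernels, ep_size/ep_rank, in_dtype) when picking
+        # the dispatch/combine implementation for this layer.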
+ quant_method: Optional[FusedMoEMethodBase] = None + if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod()) + quant_method = UnquantizedFusedMoEMethod(moe) else: - self.quant_method = quant_config.get_quant_method(self, prefix) - assert self.quant_method is not None + quant_method = quant_config.get_quant_method(self, prefix) + assert isinstance(quant_method, FusedMoEMethodBase) + + assert quant_method is not None + self.quant_method = quant_method + + dispatch_combine = construct_dispatch_combine(moe, quant_config) + + success = self.quant_method.set_dispatch_combine(dispatch_combine) + + if not success: + logger.warning("DP+EP not supported for %s.", + type(self.quant_method)) self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { @@ -516,6 +813,38 @@ def __init__( self.quant_method.create_weights(layer=self, **moe_quant_params) + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, @@ -832,21 +1161,91 @@ def naive_multicast(self, x: torch.Tensor, return buffer + # TODO: will this be cudagraph-able? (probably not) + # This should not be necessary. + def invalid_pplx(self, hidden_states: torch.Tensor) -> bool: + return has_pplx and hidden_states.shape[0] < self.dp_size + + def must_reduce_shared_outputs(self) -> bool: + return self.dp_size > 1 and self.use_ep and has_pplx + + def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): + """ + The pplx combine kernel reduce across GPU ranks by default. The pplx kernels are + used when EP is enabled. In that case, this function is a no-op. + """ + if self.dp_size > 1 and self.use_ep and has_pplx: + return final_hidden_states + else: + return tensor_model_parallel_all_reduce(final_hidden_states) + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - if self.use_direct_call: + if self.use_direct_call or self.invalid_pplx(hidden_states): return self.forward_impl(hidden_states, router_logits) else: return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) + def forward_impl_chunked(self, full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor): + + full_final_hidden_states = torch.empty_like(full_hidden_states) + + def process_chunk(chunk_start, chunk_end, skip_result_store=False): + hidden_states = full_hidden_states[chunk_start:chunk_end, :] + router_logits = full_router_logits[chunk_start:chunk_end, :] + + # Matrix multiply. 
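+            # Each DP rank handles at most MOE_DP_CHUNK_SIZE tokens per
+            # iteration of the loop below; ranks that have exhausted their own
+            # tokens still run the chunk (with skip_result_store=True) so the
+            # dispatch/combine collectives stay in lockstep across ranks.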
+            final_hidden_states = self.quant_method.apply(
+                layer=self,
+                x=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                renormalize=self.renormalize,
+                use_grouped_topk=self.use_grouped_topk,
+                global_num_experts=self.global_num_experts,
+                expert_map=self.expert_map,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                custom_routing_function=self.custom_routing_function,
+                scoring_func=self.scoring_func,
+                e_score_correction_bias=self.e_score_correction_bias,
+                activation=self.activation,
+            )
+
+            if not skip_result_store:
+                full_final_hidden_states[chunk_start:chunk_end, :].copy_(
+                    final_hidden_states)
+
+        max_tokens_across_dp = get_forward_context().dp_metadata.max_tokens_across_dp
+        moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
+
+        num_tokens = full_hidden_states.size(0)
+        for chunk_start_ in range(0, max_tokens_across_dp,
+                                  moe_dp_chunk_size_per_rank):
+            chunk_start = chunk_start_
+            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
+                            max_tokens_across_dp)
+            # clamp start and end
+            chunk_start = min(chunk_start, num_tokens - 1)
+            chunk_end = min(chunk_end, num_tokens)
+
+            process_chunk(chunk_start,
+                          chunk_end,
+                          skip_result_store=chunk_start_ >= num_tokens)
+
+        return full_final_hidden_states
+
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None
+        if self.dp_size > 1 and self.use_ep and has_pplx:
+            return self.forward_impl_chunked(hidden_states, router_logits)
 
         if self.dp_size > 1:
-            cu_tokens_across_dp_cpu = get_forward_context(
-            ).dp_metadata.cu_tokens_across_dp_cpu
+            ctx = get_forward_context()
+            cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu
 
             hidden_states = self.naive_multicast(hidden_states,
                                                  cu_tokens_across_dp_cpu)
@@ -880,7 +1279,8 @@ def forward_impl(self, hidden_states: torch.Tensor,
             all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
             final_hidden_states = all_hidden_states[start:end, :]
 
-        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+        if self.reduce_results and (self.tp_size > 1
+                                    or self.ep_size > 1):
             # Default set to False. (May have to add shared expert outputs.)
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
new file mode 100644
index 00000000000..eec5a7406d9
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -0,0 +1,363 @@
+# SPDX-License-Identifier: Apache-2.0
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+
+import torch
+
+#
+# This file defines a set of base classes used to make MoE kernels more modular.
+# The goal is to be able to utilize different communication mechanisms with
+# any fused MoE kernel without needing to have combinatoric implementations.
+#
+# The fused moe kernels are broken down into the following components:
+#
+# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
+#
+# Each component will be independent of the others except for
+# [Quantize-Dispatch] and [Combine] (see below). The components can then be
+# mixed and matched so that DP+EP can be supported easily for multiple
+# MoE kernel implementations.
+#
+# The following main classes are defined:
+# * FusedMoEQuantizeDispatchCombine - an abstract base class for quantization,
+#   dispatching and combining.
The dispatch method takes care of any needed +# quantization and the combine method applies weights and does the final +# reduction of the output. +# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused +# MoE operation. One important feature to note is that this class does not +# apply topk weights or reduce the final output. +# * FusedMoEModularKernel - an interface class that combines a +# FusedMoEQuantizeDispatchCombine and a FusedMoEPermuteExpertsUnpermute to +# provide the standard fused MoE kernel interface. +# +# [Quantize-Dispatch] and [Combine] functionality are bundled into a single +# class `FusedMoEQuantizeDispatchCombine` since they could use collective +# communication mechanisms that need to be consistent. +# + + +def _moe_problem_size( + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, +) -> Tuple[int, int, int, int, int]: + """ + Extract the MoE problem size from the given tensor arguments: + - a: The hidden states, input to the MoE layer. + - w1: The first set of expert weights. + - w2: The second set of expert weights. + - topk_ids: The topk ids. + + Note: extracting the problem shape from the weight and activation tensors is + not obvious. It needs to be done this way specifically due to subtle issues + with particular kernels, e.g. the int4 kernels divide the trailing dimension + by two, so it's not "correct" to extract N or K from the trailing dimension + of w1 or w2. Similarly, some kernels transpose the weights, so this needs + to be kept in mind. + """ + assert w1.dim() == 3 and w2.dim() == 3 + E, N, _ = w1.shape + K = w2.shape[1] + + if a1.dim() == 2: + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.shape[0] == a1.shape[0], \ + f"{topk_ids.shape[0]} != {a1.shape[0]}" + M = a1.shape[0] + else: + assert a1.dim() == 3 + assert a1.shape[0] == E + M = a1.shape[1] # This is max_num_tokens + + assert topk_ids.dim() == 2 + topk = topk_ids.shape[1] + + return E, M, N, K, topk + + +class FusedMoEQuantizeDispatchCombine(ABC): + """ + An abstract base class for the [Quantize-Dispatch] and [Combine] steps + described above. + """ + + @abstractmethod + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform any quantization (and/or) dispatching needed + for this kernel. + - a1: The (unquantized) input to the MoE layer. + - a1_scale: Optional scales for a1 + - a2_scale: Optional scales for the second MoE gemm. Required to make + sure the quantization is consistent for both gemms. + - topk_ids: The topk ids. + - topk_weights: The topk weights. + - num_experts: The total number of experts in the global expert space. + - expert_map: A tensor mapping expert indices from the global expert + space to the local expert space of the expert parallel shard. + - apply_router_weight_on_input: When True, apply the weights to the + activations, before quantization + dispatching. + + Returns a tuple of: + - quantized + dispatched a. + - quantized + dispatched a1_scales. 
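+        - Optional tensor with the number of tokens assigned to each local
+          expert (used by the batched experts format).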
+ """ + raise NotImplementedError + + @abstractmethod + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + """ + Perform any combine plus apply weights and perform a reduction on the + fused experts output. + - output: The output tensor, written in place. Must be (M, K) shape. + - fused_expert_output: The unweighted, unreduced output of the fused + experts, it will have (M, topk, K) shape. + - topk_weights: The weights to be applied to the fused_experts_output. + - topk_ids: The topk_ids. + - apply_router_weight_on_input: When False, apply the weights to + fused_expert_output. + """ + raise NotImplementedError + + +class FusedMoEPermuteExpertsUnpermute(ABC): + """ + An abstract base class for the [Permute-Experts-Unpermute] step described + above. + """ + + @abstractmethod + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: + """ + Compute the number of elements for the temporary outputs of the two + gemms and activation in the fused expert function. Since the + gemms are independent, the workspace for the first gemm can be shared + with the workspace for the last gemm. + + Returns a tuple of: + - Number of workspace13 elements: must be large enough to hold the + result of either expert gemm. + - Number of workspace2 elements: must be large enough to hold the + result of the activation function. + - Workspace type: The dtype to use for the workspace tensors. + """ + raise NotImplementedError + + def activation(self, activation: str, output: torch.Tensor, + input: torch.Tensor) -> None: + if activation == "silu": + torch.ops._C.silu_and_mul(output, input) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(output, input) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + @abstractmethod + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + """ + This function computes the intermediate result of a Mixture of Experts + (MoE) layer using two sets of weights, w1 and w2. + + Parameters: + - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE + layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_ids (torch.Tensor): A map of row to expert id. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for + w1. 
+ - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for + w2. + - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be + used for a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. + - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs + must be large enough to hold output of either MoE gemm. + - workspace2 (torch.Tensor): A scratch tensor used for the activation + function. + - expert_num_tokens: An optional tensor containing the number of tokens + assigned to each expert when using batched experts format input. + + Returns: + - torch.Tensor: The unweighted, unreduced output tensor + """ + raise NotImplementedError + + +class FusedMoEModularKernel(torch.nn.Module): + """ + This class combines a FusedMoEQuantizeDispatchCombine instance and + a FusedMoEPermuteExpertsUnpermute to provide an interface that + is compatible with the `fused_experts` function in fused_moe.py. + + It takes care of managing any required scratch space. + + Note: Instances of this class should only be used for a single model + layer due to any layer specific state that may be used by the component + objects. + """ + + def __init__( + self, + dispatch_combine: FusedMoEQuantizeDispatchCombine, + fused_experts: FusedMoEPermuteExpertsUnpermute, + ): + super().__init__() + self.dispatch_combine = dispatch_combine + self.fused_experts = fused_experts + + def forward( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + ) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets + of weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states: (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_weights (torch.Tensor): The topk weights applied at the end of + the layer. + - topk_ids (torch.Tensor): A map of row to expert id. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for + w1. + - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is + 1. 
+ + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + a1 = hidden_states + E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) + + if global_num_experts == -1: + global_num_experts = E + + output = a1 if inplace else torch.empty_like(a1) + + workspace13_shape, workspace2_shape, workspace_dtype = ( + self.fused_experts.workspace_shapes(a1, M, N, K, top_k, + global_num_experts)) + + # We can reuse the memory between cache1 and cache3 because by the time + # we need cache3, we're done with cache1 + workspace13 = torch.empty(workspace13_shape, + device=a1.device, + dtype=workspace_dtype) + workspace2 = torch.empty(workspace2_shape, + device=a1.device, + dtype=workspace_dtype) + + a1q, a1q_scale, expert_num_tokens = self.dispatch_combine.dispatch( + a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, + expert_map, apply_router_weight_on_input) + + fused_out = self.fused_experts.apply( + a1q, + w1, + w2, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) + + self.dispatch_combine.combine(output, fused_out, topk_weights, + topk_ids, apply_router_weight_on_input) + + return output diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py new file mode 100644 index 00000000000..0d89f3f22c3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm + + +def _moe_permute( + curr_hidden_states: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + curr_topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + block_m: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: + """ + Determine the sorted_token_ids, expert_ids for the given problem size. + Permute the hidden states and scales according to `sorted_token_ids`. + """ + top_k_num = curr_topk_ids.shape[1] + + tokens_in_chunk, _ = curr_hidden_states.shape + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, + block_m, + global_num_experts, + expert_map, + pad_sorted_ids=True)) + + inv_perm: Optional[torch.Tensor] = None + + num_tokens = top_k_num * tokens_in_chunk + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) + inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] + + # Permute according to sorted token ids. 
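+    # sorted_token_ids indexes the flattened (token, topk) space, so dividing
+    # by top_k recovers the source row of curr_hidden_states for each entry.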
+ curr_hidden_states = _fp8_perm(curr_hidden_states, + sorted_token_ids // top_k_num) + + if a1q_scale is not None: + a1q_scale = a1q_scale[sorted_token_ids // top_k_num] + + return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) + + +def _moe_unpermute_and_reduce( + out: torch.Tensor, + curr_hidden: torch.Tensor, + inv_perm: Optional[torch.Tensor], + topk_weight: torch.Tensor, + apply_router_weight_on_input: bool, +) -> None: + """ + Unpermute the final result and apply topk_weights, then perform the final + reduction on the hidden states. + """ + M, topk = topk_weight.shape + K = curr_hidden.shape[-1] + if inv_perm is not None: + curr_hidden = curr_hidden[inv_perm, ...] + curr_hidden = curr_hidden.view(-1, topk, K) + if not apply_router_weight_on_input: + curr_hidden.mul_(topk_weight.view(M, -1, 1)) + ops.moe_sum(curr_hidden, out) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py new file mode 100644 index 00000000000..99049175250 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Optional, Tuple + +import pplx_kernels as pplx +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) + + +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + +# Note use: layer.get_all_to_all() to get an AllToAll instance +# The max_num_tokens, world_size and dp_size must be the same +# as the ones used to create the AllToAll. +class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + + def __init__(self, + a2a: pplx.AllToAll, + max_num_tokens: int, + world_size: int, + rank: int, + dp_size: int, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[List[int]] = None): + super().__init__() + assert max_num_tokens > 0 + self.a2a = a2a + self.block_shape = block_shape + self.max_num_tokens = max_num_tokens + self.world_size = world_size + self.rank = rank + self.dp_size = dp_size + self.quant_dtype = quant_dtype + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + rank_topk_weights: torch.Tensor, + rank_topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + num_tokens = a1.shape[0] # M + hidden_dim = a1.shape[-1] # K + + assert rank_topk_ids.shape[0] == num_tokens + # assert expert_map is None, "NYI" + + # Is this always going to be a1.device? 
+ device = a1.device + + if apply_router_weight_on_input: + topk = rank_topk_ids.shape[1] + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1 = a1 * rank_topk_weights.to(a1.dtype) + + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, + self.quant_dtype, + per_act_token, + self.block_shape) + + rem_experts = num_experts % self.world_size + num_local_experts = ((num_experts // self.world_size) + + (1 if self.rank < rem_experts else 0)) + + expert_num_tokens = torch.empty( + num_local_experts, + dtype=torch.int32, + device=device, + ) + + num_dp = self.world_size // self.dp_size + expert_x = torch.empty( + (num_local_experts, self.max_num_tokens * num_dp, hidden_dim), + dtype=a1q.dtype, + device=device, + ) + + expert_x_scale: Optional[torch.Tensor] = None + if a1q.dtype.itemsize == 1: + float32_size = torch.float32.itemsize + block_size = (self.block_shape[0] if self.block_shape is not None + else 1) * float32_size + expert_x_scale = torch.empty( + ( + num_experts, + expert_x.size(1), + (expert_x.size(2) + block_size - 1) // block_size, + ), + dtype=torch.float32, + device=device, + ) + + # This argument is optional, defaults to indices.shape[0] + bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=a1q.device) + + # TODO: optimize this? + indices = rank_topk_ids.to(dtype=torch.uint32) + + self.a2a.dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=indices, + bound_m=bound_m, + ) + + return expert_x, expert_x_scale, expert_num_tokens + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + # This argument is optional + num_tokens = output.shape[0] # M + bound_m = torch.tensor([num_tokens], + dtype=torch.uint32, + device=fused_expert_output.device) + + assert topk_ids.shape[0] <= num_tokens + assert output.shape[0] <= self.max_num_tokens, \ + f"{output.shape[0]} <= {self.max_num_tokens}" + assert output.shape[1] == fused_expert_output.shape[-1] + + # Set weights to 1 if we did them in dispatch. This is hacky. 
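+        # Otherwise the combine kernel would apply the router weights a
+        # second time on top of the already-weighted activations.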
+ if apply_router_weight_on_input: + topk_weights = torch.ones_like(topk_weights) + + self.a2a.combine(out_tokens=output, + indices=topk_ids.to(torch.uint32), + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py new file mode 100644 index 00000000000..5ddb0e66842 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Optional, Tuple + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts + + +class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__(self, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, + allow_deep_gemm: bool = False): + super().__init__() + self.triton_expert = TritonExperts(use_fp8_w8a8, use_int8_w8a8, + use_int4_w4a16, use_int8_w8a16, + per_channel_quant, block_shape, + block_m) + self.deep_gemm_expert = DeepGemmExperts() + self.allow_deep_gemm = allow_deep_gemm + self.use_fp8_w8a8 = use_fp8_w8a8 + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: + # Note: the deep gemm workspaces are strictly larger than the triton + # workspaces so we can be pessimistic here and allocate for DeepGemm + # even if we fall back to triton later, e.g. if expert maps are set. 
+ if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): + return self.deep_gemm_expert.workspace_shapes( + a, M, N, K, topk, num_experts) + else: + return self.triton_expert.workspace_shapes(a, M, N, K, topk, + num_experts) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + N = w1.shape[1] + if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 + and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): + return self.deep_gemm_expert( + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + expert_num_tokens, + ) + else: + return self.triton_expert( + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + expert_num_tokens, + ) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index db31422f727..d53da1d7926 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -7,6 +7,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_group_quant_int8, per_token_quant_int8) from vllm.utils import cdiv @@ -15,34 +17,80 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: Shrink the given tensor and apply the given view to it. This is used to resize the intermediate fused_moe caches. """ - assert prod(v) <= x.numel() + assert prod(v) <= x.numel(), f"{prod(v)} <= {x.numel()}" return x.flatten()[:prod(v)].view(*v) def _fp8_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], - block_shape: Optional[List[int]], + per_act_token: bool, + block_shape: Optional[List[int]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Perform fp8 quantization on the inputs. If a block_shape is provided, the output will be blocked. """ if block_shape is None: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) + A, A_scale = ops.scaled_fp8_quant( + A, A_scale, use_per_token_if_dynamic=per_act_token) else: assert len(block_shape) == 2 _, block_k = block_shape[0], block_shape[1] A, A_scale = per_token_group_quant_fp8(A, block_k) assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + + return A, A_scale + + +def _int8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token: bool, + block_shape: Optional[List[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform int8 quantization on the inputs. If a block_shape + is provided, the output will be blocked. + """ + + # If weights are per-channel (per_channel_quant=True), then + # activations apply per-token quantization. 
Otherwise, assume + # activation tensor-wise fp8/int8 quantization, dynamic or static + if block_shape is None: + assert per_act_token, \ + "int8 quantization only supports block or channel-wise" + A, A_scale = per_token_quant_int8(A) + else: + assert len(block_shape) == 2 + _, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_int8(A, block_k) + assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + return A, A_scale +def moe_kernel_quantize_input( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + qtype: Optional[torch.dtype], + per_channel_quant: bool, + block_shape: Optional[List[int]] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if qtype == torch.float8_e4m3fn: + return _fp8_quantize(A, A_scale, per_channel_quant, block_shape) + elif qtype == torch.int8: + return _int8_quantize(A, A_scale, per_channel_quant, block_shape) + else: + assert A_scale is None + return A, A_scale + + def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: """ A permutation routine that works on fp8 types. """ - if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: + if torch.is_floating_point(m) and m.dtype.itemsize == 1: return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) else: return m[idx, ...] diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5515ba27ea1..acf43284a93 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import functools import importlib.util from typing import Any, Callable, Dict, List, Optional @@ -9,6 +10,7 @@ from torch.nn.parameter import Parameter import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -441,6 +443,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): """ def __init__(self, quant_config: Fp8Config): + from vllm.model_executor.layers.fused_moe import fused_experts self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None @@ -457,6 +460,11 @@ def __init__(self, quant_config: Fp8Config): logger.warning_once( "DeepGemm not supported on the current platform.") + self.fused_experts = functools.partial( + fused_experts, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm) + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -768,6 +776,25 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False) return + def set_dispatch_combine( + self, + dispatch_combine: mk.FusedMoEQuantizeDispatchCombine) -> bool: + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) + + experts = TritonOrDeepGemmExperts( + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, + ) + + self.fused_experts = mk.FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + return True + def apply( self, layer: torch.nn.Module, @@ -786,8 +813,6 @@ def apply( apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = FusedMoE.select_experts( 
hidden_states=x, router_logits=router_logits, @@ -801,10 +826,10 @@ def apply( e_score_correction_bias=e_score_correction_bias, ) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, @@ -819,8 +844,6 @@ def apply( if self.block_quant else layer.w2_weight_scale), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, ) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index ffa5840b460..2e12095d3ae 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -32,8 +32,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -143,7 +142,13 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=False, + # When just tensor-parallel is used, it isn't required + # to reduce the shared_output result. Instead we reduce + # at the end of the forward pass. + # With EP and the pplx kernels - this is no longer viable + # as all GPU ranks in DP, produce the complete set of hidden_states. + # Therefore reduce the shared experts early. + reduce_results=self.experts.must_reduce_shared_outputs(), prefix=f"{prefix}.shared_experts", ) @@ -154,6 +159,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) + if hidden_states.dtype != torch.float16: final_hidden_states = self.experts( hidden_states=hidden_states, @@ -171,9 +177,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # See DeepseekV2DecoderLayer for more details. final_hidden_states = final_hidden_states + shared_output \ * (1. 
/ self.routed_scaling_factor) + if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) @@ -459,6 +465,7 @@ def __init__( o_proj=self.o_proj, ) + self.prefix = prefix self.debug_layer_idx = int(self.prefix.split(".")[-2]) @@ -480,7 +487,6 @@ def forward( k_pe, output_shape=hidden_states.shape) - class DeepseekV2DecoderLayer(nn.Module): def __init__( @@ -558,6 +564,7 @@ def forward( else: hidden_states, residual = self.input_layernorm( hidden_states, residual) + hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -576,6 +583,7 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) + hidden_states = self.mlp(hidden_states) if isinstance(self.mlp, diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 0fdc30f36f9..68e427d272c 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -102,7 +102,7 @@ def forward(self, hidden_states): experts_out = routed_out + shared_out if self.tp_size > 1: - experts_out = tensor_model_parallel_all_reduce(experts_out) + experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(experts_out) return experts_out diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 47d90919ed8..aca0d658d88 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -156,7 +156,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states) return final_hidden_states.view(orig_shape) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 97acbaa2ac3..904dc2c452f 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -137,7 +137,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits=router_logits) final_hidden_states = final_hidden_states if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states) return final_hidden_states.view(orig_shape) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index fd3be901f4c..8150d8900bb 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -901,7 +901,7 @@ def forward( if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) num_actual_toks = attn_metadata.num_actual_tokens