diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 7078c997578a..23bf7ba74101 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -33,26 +33,26 @@ def fused_moe_kernel(
         expert_ids_ptr,
         num_tokens_post_padded_ptr,
         # Matrix dimensions
-        N,
-        K,
-        EM,
-        num_valid_tokens,
+        N: tl.int64,
+        K: tl.int64,
+        EM: tl.int64,
+        num_valid_tokens: tl.int64,
         # The stride variables represent how much to increase the ptr by when
         # moving by 1 element in a particular dimension. E.g. `stride_am` is
         # how much to increase `a_ptr` by to get the element one row down
         # (A has M rows).
-        stride_am,
-        stride_ak,
-        stride_be,
-        stride_bk,
-        stride_bn,
-        stride_cm,
-        stride_cn,
-        stride_asm,
-        stride_ask,
-        stride_bse,
-        stride_bsk,
-        stride_bsn,
+        stride_am: tl.int64,
+        stride_ak: tl.int64,
+        stride_be: tl.int64,
+        stride_bk: tl.int64,
+        stride_bn: tl.int64,
+        stride_cm: tl.int64,
+        stride_cn: tl.int64,
+        stride_asm: tl.int64,
+        stride_ask: tl.int64,
+        stride_bse: tl.int64,
+        stride_bsk: tl.int64,
+        stride_bsn: tl.int64,
         # Block size for block-wise quantization
         group_n: tl.constexpr,
         group_k: tl.constexpr,
@@ -114,18 +114,16 @@ def fused_moe_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
-        tl.int64)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     token_mask = offs_token < num_valid_tokens
 
-    offs_bn = (pid_n * BLOCK_SIZE_N +
-               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                       offs_k[None, :] * stride_ak)
 
-    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                 offs_bn[None, :] * stride_bn)
     if use_int8_w8a16:
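
Context for the change: Triton specializes scalar integer kernel arguments to 32-bit when their runtime values fit, so index math such as `off_experts * stride_be` could wrap once a weight tensor exceeds 2**31 elements. Annotating the dimensions and strides as `tl.int64` makes every product involving them 64-bit by integer promotion, which is also why the per-use `.to(tl.int64)` casts on the offset vectors can be dropped. Below is a minimal, self-contained sketch of the same pattern on a hypothetical `copy_rows_kernel` (not part of this diff), assuming a CUDA device and a recent Triton:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_rows_kernel(
        src_ptr,
        dst_ptr,
        n_cols,  # left unannotated: Triton specializes it (int32 if it fits)
        stride_row: tl.int64,  # annotated, as the diff does for stride_*
        BLOCK_SIZE: tl.constexpr):
    # One program per row. Promoting the stride (not each offset) mirrors
    # the diff above: int32 * int64 promotes to int64, so `row * stride_row`
    # cannot wrap even when the tensor holds more than 2**31 elements.
    row = tl.program_id(axis=0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols
    vals = tl.load(src_ptr + row * stride_row + offs, mask=mask)
    tl.store(dst_ptr + row * stride_row + offs, vals, mask=mask)


x = torch.randn(4, 96, device="cuda")
y = torch.empty_like(x)
# Grid: one program per row; BLOCK_SIZE must be a power of two >= n_cols here.
copy_rows_kernel[(x.shape[0], )](x, y, x.shape[1], x.stride(0), BLOCK_SIZE=128)
torch.testing.assert_close(y, x)
```

Annotating at the signature keeps the inner loop free of explicit casts, at the cost of slightly wider index arithmetic wherever the annotated arguments are used.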