diff --git a/CMakeLists.txt b/CMakeLists.txt index 272bdb13c7..cf9a5c31fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,12 +58,14 @@ ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/*.cpp) ascendc_library(vllm_ascend_kernels SHARED ${KERNEL_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp ) message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}") file(GLOB VLLM_ASCEND_SRC -${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp) +${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp +${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp) include_directories( ${pybind11_INCLUDE_DIRS} @@ -73,6 +75,7 @@ include_directories( ${ASCEND_HOME_PATH}/include ${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform ${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform + ${CMAKE_CURRENT_SOURCE_DIR}/csrc ) set( diff --git a/csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h b/csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h new file mode 100644 index 0000000000..597545872c --- /dev/null +++ b/csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h @@ -0,0 +1,123 @@ +#include +#include +#include "acl/acl.h" +#include "kernel_tiling/kernel_tiling.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/tiling_data.h" +#include "common_tiling.h" + + +namespace bmm_trans { +using namespace pp_matmul; + +std::unordered_map quantModeMap = { + {"per_channel_symm", 0}, + {"per_channel_asymm", 1}, + {"per_token_symm", 2}, +}; + +std::unordered_map formatModeMap = { + {"ND", 0}, + {"NZ", 1}, +}; + +std::unordered_map atType2tensorDType = { + {at::ScalarType::BFloat16, TensorDType::TENSOR_DTYPE_BF16}, + {at::ScalarType::Half, TensorDType::TENSOR_DTYPE_FLOAT16}}; + +// batch size -> memory index +constexpr uint32_t MAX_CAPTURE_NUM = 1024; + +template +inline int GetModeVal(const MapType &mode_map, c10::optional mode_opt, c10::string_view default_mode, + const char *mode_name) +{ + std::string modeStr(mode_name); + c10::string_view mode_str = mode_opt.value_or(default_mode); + auto it = mode_map.find(mode_str); + // if input mode is unsupported, use default value + TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str)); + return it->second; +} + +std::tuple batch_matmul_transpose_tiling(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c, + c10::optional format_mode, + c10::optional quant_mode) +{ + auto tensorAShape = tensor_a.sizes(); + auto tensorBShape = tensor_b.sizes(); + auto tensorCShape = tensor_c.sizes(); + uint32_t n; + uint32_t block_dim; + + //auto &platform = PlatformInfo::Instance(); + HardwareInfo hwInfo; + std::map dTypeMap = {{at::ScalarType::Half, 2.0}, {at::ScalarType::BFloat16, 2.0}}; + + at::ScalarType aType = tensor_a.scalar_type(); + at::ScalarType bType = tensor_b.scalar_type(); + at::ScalarType cType = tensor_c.scalar_type(); + TORCH_CHECK(aType == bType && bType == cType, "tensor type is not the same"); + TORCH_CHECK((aType == at::ScalarType::BFloat16) || (aType == at::ScalarType::Half), + "tensor type only support half or bf16"); + + TensorFormat formatMode = static_cast(GetModeVal(formatModeMap, format_mode, "ND", "format_mode")); + MatMul::QuantMode quantMode = + static_cast(GetModeVal(quantModeMap, quant_mode, "per_channel_symm", "quant_mode")); + + TORCH_CHECK(tensorAShape.size() == 3, "batch size is not same between srcTensor and 
dstTensor"); + if (formatMode == TensorFormat::TENSOR_FORMAT_ND) { + TORCH_CHECK(tensorBShape.size() == 3, "tensor shape should be dim3 in ND format"); + TORCH_CHECK(tensorAShape[2] == tensorBShape[1], "tensor shape is wrong"); + n = tensorBShape[2]; + } else { + TORCH_CHECK(tensorBShape.size() == 4, "tensor shape should be dim4 in nz format"); + TORCH_CHECK(tensorAShape[2] == tensorBShape[2], "tensor shape is wrong"); + n = tensorBShape[1] * tensorBShape[3]; + } + TORCH_CHECK(tensorAShape[1] == tensorBShape[0], "tensor shape is wrong"); + + OpShape opShape = {.batchSize = static_cast(tensorAShape[1]), + .m = static_cast(tensorAShape[0]), + .k = static_cast(tensorAShape[2]), + .n = n}; + pp_matmul::PpMatmulTilingData matmulTilingData = { + .opShape = opShape, + }; + auto dType = atType2tensorDType[aType]; + MatMulInfo mmInfo = {.batchSize = opShape.batchSize, + .m = opShape.m, + .k = opShape.k, + .n = opShape.n, + .dtypeA = dType, + .dtypeB = dType, + .dtypeC = dType, + .formatB = formatMode, + .mmType = MatMul::MatMulType::MATMUL_EIN_SUM, + .inDtype = dTypeMap[aType], + .outDtype = dTypeMap[cType], + .quantMode = quantMode}; + GetPpMatmulTiling(mmInfo, hwInfo, block_dim, matmulTilingData); + host_utils::PpMatmulTilingCheck(matmulTilingData); + + // tiling + int32_t batchIdx = opShape.m - 1; + uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData); + static auto global_tiling_data = at::empty( + {tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device())); + if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) { + aclrtMemcpy(global_tiling_data.data_ptr() + (tilingSize * batchIdx), tilingSize, &matmulTilingData, + tilingSize, ACL_MEMCPY_HOST_TO_DEVICE); + } else { + // Handle the case where batchIdx is out of range + TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx); + } + at::Tensor tiling_tensor = + at::from_blob(global_tiling_data.data_ptr() + (tilingSize * batchIdx), tilingSize, at::kByte); + + return std::make_tuple(tiling_tensor, block_dim); + +} + +} + diff --git a/csrc/batch_matmul_transpose/op_host/common.h b/csrc/batch_matmul_transpose/op_host/common.h new file mode 100644 index 0000000000..82abd10e95 --- /dev/null +++ b/csrc/batch_matmul_transpose/op_host/common.h @@ -0,0 +1,57 @@ + +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef UTILS_COMMON_H +#define UTILS_COMMON_H + +namespace host_utils { + +constexpr uint32_t BLK_SIZE_ALIN_FOR_INT64 = 4; +constexpr uint32_t BLK_SIZE_ALIN_FOR_INT32 = 8; + +inline uint64_t alinInt64Count(uint64_t count) +{ + return (count + BLK_SIZE_ALIN_FOR_INT64 - 1) / BLK_SIZE_ALIN_FOR_INT64 * BLK_SIZE_ALIN_FOR_INT64; +} + +inline uint64_t alinInt32Count(uint64_t count) +{ + return (count + BLK_SIZE_ALIN_FOR_INT32 - 1) / BLK_SIZE_ALIN_FOR_INT32 * BLK_SIZE_ALIN_FOR_INT32; +} + +template +inline T CeilDiv(const T dividend, const T divisor) +{ + if (divisor == 0) { + return UINT32_MAX; + } + return (dividend + divisor - 1) / divisor; +} + +template +inline T RoundUp(const T val, const T align = 16) +{ + if (align == 0 || val + align - 1 < val) { + return 0; + } + return (val + align - 1) / align * align; +} + +template +inline T RoundDown(const T val, const T align = 16) +{ + if (align == 0) { + return 0; + } + return val / align * align; +} +} // namespace host_utils +#endif // UTILS_COMMON_H diff --git a/csrc/batch_matmul_transpose/op_host/common_tiling.h b/csrc/batch_matmul_transpose/op_host/common_tiling.h new file mode 100644 index 0000000000..4fac5c5bfa --- /dev/null +++ b/csrc/batch_matmul_transpose/op_host/common_tiling.h @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COMMMON_TILING_H +#define COMMMON_TILING_H + +#include +#include +#include "common.h" +#include "tiling/platform/platform_ascendc.h" + +namespace host_utils { + +constexpr uint32_t FP16_SIZE = 2; +constexpr uint32_t FP32_SIZE = 4; +constexpr uint32_t BLOCK_SIZE = 16; +constexpr uint32_t BLOCK_SIZE_INT8_K = 32; +constexpr uint32_t BASE_BLOCK_STEP = 2; +constexpr uint32_t AXES_ALIGN_SIZE = 512; +constexpr uint32_t AXES_ALIGN_SIZE_INT8 = 256; +constexpr uint32_t ND_SHAPE_SIZE = 2; +constexpr uint32_t NZ_SHAPE_SIZE = 4; +constexpr uint32_t CUBE_BLOCK_SIZE = 256; +constexpr uint32_t CUBE_BLOCK_SIZE_INT8 = 512; +constexpr uint32_t L1AB_PINGPONG_BUFFER_LEN = 262144; +constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN_INT8 = 131072 * 2; // 256 KB +constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN_FP16 = 131072; // 128 KB +constexpr uint32_t L1AB_PINGPONG_BUFFER_LEN_INT8_SPARSE = 160 * 1024; +constexpr uint32_t UB_LIMIT_SIZE_910A = 128 * 1024; + +enum class PlatformType { ASCEND_310P, ASCEND_910A, ASCEND_910B, ASCEND_910C, PLATFORM_INVALID }; + +struct PlatformInfo { +public: + static const PlatformInfo &Instance() + { + static PlatformInfo platformInfo; + return platformInfo; + } + + PlatformType socType; + uint32_t coreNum; + uint32_t coreNumAic; + uint32_t coreNumAiv; + uint64_t ubSize; + uint64_t l1Size; + uint64_t l2Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; + +private: + PlatformInfo() + { + auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance(); + // TODO Hard coding set to 910_93xx, parse using aclrtGetSocName is better + socType = PlatformType::ASCEND_910C; + coreNum = ascendcPlatform->GetCoreNum(); + coreNumAic = ascendcPlatform->GetCoreNumAic(); + coreNumAiv = ascendcPlatform->GetCoreNumAiv(); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L1, l1Size); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L2, l2Size); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, l0aSize); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, l0bSize); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, l0cSize); + } + + PlatformInfo(const PlatformInfo &) = delete; + PlatformInfo &operator=(const PlatformInfo &) = delete; + PlatformInfo(PlatformInfo &&) = delete; + PlatformInfo &operator=(PlatformInfo &&) = delete; +}; + +inline __attribute__((always_inline)) uint32_t GetN0TilingLimit(bool compressFlag, uint32_t tilingN, + const PlatformType &platformType) +{ + if (compressFlag) { + return std::min(tilingN * BLOCK_SIZE, AXES_ALIGN_SIZE_INT8); + } else { + return (platformType == PlatformType::ASCEND_310P || platformType == PlatformType::ASCEND_910A) + ? AXES_ALIGN_SIZE + : AXES_ALIGN_SIZE_INT8; + } +} + +template +inline __attribute__((always_inline)) uint32_t GetN0TilingInit(const OpShareType &opShape, bool compressFlag, + uint32_t tilingN) +{ + const uint32_t rnd = 16; + return compressFlag + ? ((tilingN * BLOCK_SIZE > opShape.n) ? 
RoundUp(opShape.n, rnd) : tilingN * BLOCK_SIZE) + : BLOCK_SIZE; +} + +template +inline __attribute__((always_inline)) bool IsExceedTilingLimit(uint32_t axes0, uint32_t priAxes0, + uint32_t n0TilingLimit, PlatformType platformType, + uint32_t basicBlockSize) +{ + return (PRI_FLAG && axes0 > n0TilingLimit) || (!PRI_FLAG && priAxes0 > n0TilingLimit) || + (platformType == PlatformType::ASCEND_910A && basicBlockSize > UB_LIMIT_SIZE_910A); +} + +template +inline __attribute__((always_inline)) void SetOpShapeAxesInfo(OpShareType &opShape, uint32_t priAxes0, uint32_t axes0) +{ + opShape.m0 = PRI_FLAG ? priAxes0 : axes0; + opShape.n0 = PRI_FLAG ? axes0 : priAxes0; +} + +template +inline __attribute__((always_inline)) float CostFunc(const HardwareType &hwInfor, OpShapeType &shape) +{ + float aCoef = 1; + float bCoef = 1; + float bwCoef = static_cast(hwInfor.l2BandWidth) / static_cast(hwInfor.hbmBandWidth); + uint32_t mLoop = CeilDiv(shape.m, shape.m0); + uint32_t nLoop = CeilDiv(shape.n, shape.n0); + if (mLoop == 0 || nLoop == 0) { + return 1; + } + uint32_t coreNeed = shape.batchSize * mLoop * nLoop; + uint32_t blockDim = std::min(coreNeed, hwInfor.coreNum); + uint32_t mOnce = blockDim < nLoop ? shape.m0 : blockDim / nLoop * shape.m0; + uint32_t nOnce = blockDim < nLoop ? hwInfor.coreNum * shape.n0 : shape.n; + if (mOnce * shape.k * FP16_SIZE > hwInfor.l2Size) { + aCoef = bwCoef; + } + if (nOnce * shape.k * FP16_SIZE > hwInfor.l2Size) { + bCoef = bwCoef; + } + return 1 / (aCoef * static_cast(shape.n0)) + 1 / (bCoef * static_cast(shape.m0)); +} + +template +void TilingFunc(OpShareType &opShape, TilingType &tilingParam, const HardwareType &hwInfor, + const MatMulInfoType &mmInfo, bool compressFlag = false, const uint32_t tilingN = 1) +{ + float costMin = 1; + const float CONST_2 = 2.0; + const uint32_t ROUND_CONST_16 = 16; + uint32_t roundBase = static_cast( + pow(2, ceil(log(CeilDiv(PRI_FLAG ? opShape.n : opShape.m, ROUND_CONST_16)))) * ROUND_CONST_16); + uint32_t priAxes = RoundUp(PRI_FLAG ? opShape.m : opShape.n, ROUND_CONST_16); + uint32_t axes = RoundUp(PRI_FLAG ? opShape.n : opShape.m, roundBase); + float axes0Max = static_cast(AXES_ALIGN_SIZE) / mmInfo.inDtype; + auto platformType = PlatformInfo::Instance().socType; + if (mmInfo.isInt8 && (platformType == PlatformType::ASCEND_310P || platformType == PlatformType::ASCEND_910A)) { + axes0Max /= CONST_2; + } + + uint32_t n0TilingInit = GetN0TilingInit(opShape, compressFlag, tilingN); + uint32_t n0TilingLimit = GetN0TilingLimit(compressFlag, tilingN, platformType); + uint32_t priAxes0Init = PRI_FLAG ? BLOCK_SIZE : n0TilingInit; + uint32_t axes0Init = PRI_FLAG ? 
n0TilingInit : BLOCK_SIZE; + for (uint32_t priAxes0 = priAxes0Init; priAxes0 <= priAxes && priAxes0 <= axes0Max; priAxes0 *= BASE_BLOCK_STEP) { + for (uint32_t axes0 = axes0Init; axes0 <= axes && axes0 <= axes0Max; axes0 *= BASE_BLOCK_STEP) { + uint32_t basicBlockSize = priAxes0 * axes0 * FP32_SIZE; + if (basicBlockSize > hwInfor.l0cSize) { + continue; + } + if (mmInfo.isInt8 && + IsExceedTilingLimit(axes0, priAxes0, n0TilingLimit, platformType, basicBlockSize)) { + continue; + } + SetOpShapeAxesInfo(opShape, priAxes0, axes0); + float cost = CostFunc(hwInfor, opShape); + if (cost >= costMin) { + continue; + } + costMin = cost; + if constexpr (std::is_same::value) { + tilingParam.SetBaseOp(hwInfor.coreNum, opShape.m0, opShape.n0, mmInfo); + } else { + tilingParam.SetBaseOp(hwInfor.coreNum, opShape.m0, opShape.n0); + } + } + } +} + +template +uint32_t Swizzl(PpTilingDataType &tilingData) +{ + uint32_t swizzlDirect = 0; + uint32_t swizzlCount = 1; + float m0 = tilingData.opShape.m0; + float n0 = tilingData.opShape.n0; + float m = tilingData.opShape.m; + float k = tilingData.opShape.k; + float n = tilingData.opShape.n; + float mincost = m * k + k * n; + + for (uint32_t i = 1; i <= tilingData.blockDim; ++i) { + int c = static_cast((tilingData.blockDim + i - 1) / i); + float cost; + // B0 + A < A0 + B + if (i * n0 + m < m0 * c + n) { + swizzlDirect = 1; // Nz + cost = n0 * i + m0 * c; + if (cost <= mincost) { + mincost = cost; + swizzlCount = i; + } + } else { + swizzlDirect = 0; // Zn + cost = m0 * i + n0 * c; + if (cost < mincost) { + mincost = cost; + swizzlCount = i; + } + } + } + tilingData.swizzlDirect = swizzlDirect; + tilingData.swizzlCount = swizzlCount; + return swizzlDirect; +} + +template +inline __attribute__((always_inline)) void PpMatmulTilingCheck(const PpTilingDataType &tilingData) +{ + TORCH_CHECK(tilingData.opShape.m0 > 0, "m0 is invalid"); + TORCH_CHECK(tilingData.opShape.k0 > 0, "k0 is invalid"); + TORCH_CHECK(tilingData.opShape.n0 > 0, "n0 is invalid"); + TORCH_CHECK(tilingData.mLoop > 0, "mLoop is invalid"); + TORCH_CHECK(tilingData.kLoop > 0, "kLoop is invalid"); + TORCH_CHECK(tilingData.nLoop > 0, "nLoop is invalid"); + TORCH_CHECK(tilingData.blockDim > 0, "nLoop is invalid"); +} +} // namespace host_utils +#endif diff --git a/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp b/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp new file mode 100644 index 0000000000..5c606e5026 --- /dev/null +++ b/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp @@ -0,0 +1,155 @@ +#include +#include "tiling_data.h" +#include "batch_matmul_transpose/op_host/common.h" +#include "batch_matmul_transpose/op_host/common_tiling.h" + +namespace pp_matmul { + +constexpr uint32_t L1_DESCALE_BUFFER_LEN_MAX = 6144; +constexpr uint32_t CONST_3 = 3; +constexpr uint32_t CONST_4 = 4; +constexpr uint32_t CONST_16 = 16; +constexpr uint32_t CONST_32 = 32; +constexpr uint32_t CONST_256 = 256; +constexpr uint32_t CONST_512 = 512; + +const std::map G_DTYPE_MAP = {{TensorDType::TENSOR_DTYPE_FLOAT16, 1u}, + {TensorDType::TENSOR_DTYPE_BF16, 2u}}; +const std::map G_FORMAT_MAP = {{TensorFormat::TENSOR_FORMAT_ND, 0u}, + {TensorFormat::TENSOR_FORMAT_NZ, 1u}}; +using MmType = MatMul::MatMulType; +using QmType = MatMul::QuantMode; +using namespace host_utils; + +bool IsI8Bf16Kernel(const MatMulInfo &mmInfo) +{ + bool isI8Bf16 = mmInfo.isInt8 && mmInfo.dtypeC == TensorDType::TENSOR_DTYPE_BF16; + bool isI8Fp16 = mmInfo.isInt8 && mmInfo.dtypeC == TensorDType::TENSOR_DTYPE_FLOAT16 && + 
mmInfo.quantMode == QmType::PER_TOKEN_SYMM; + return isI8Bf16 || isI8Fp16; +} + +HardwareInfo::HardwareInfo() +{ + auto &platform = PlatformInfo::Instance(); + coreNum = platform.coreNumAic; + l2Size = platform.l2Size; + l1Size = platform.l1Size; + l0aSize = platform.l0aSize; + l0bSize = platform.l0bSize; + l0cSize = platform.l0cSize; + hbmBandWidth = 1; + l2BandWidth = 5; // 5x faster than hbm. +} + +void PpMatmulTilingData::SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n) +{ + opShape.batchSize = batchSize; + opShape.m = m; + opShape.k = k; + opShape.n = n; +} + +void PpMatmulTilingData::SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo) +{ + opShape.m0 = mBase; + opShape.n0 = nBase; + mLoop = CeilDiv(opShape.m, opShape.m0); + nLoop = CeilDiv(opShape.n, opShape.n0); + coreLoop = opShape.batchSize * mLoop * nLoop; + + if (mLoop == 1 && mmInfo.transB && coreLoop % coreNum < coreNum / CONST_4 * CONST_3) { + mBase = RoundUp(opShape.m, CONST_16); + opShape.m0 = mBase; + uint32_t maxN0 = PlatformInfo::Instance().l0cSize / (mBase * sizeof(float)); + if (mmInfo.isInt8 || mmInfo.mmType == MmType::MATMUL_WITH_BIAS) { + maxN0 = maxN0 < CONST_256 ? maxN0 : CONST_256; + } + uint32_t x = CeilDiv(opShape.n, coreNum); + uint32_t y = CeilDiv(x, maxN0); + nBase = RoundUp(CeilDiv(x, y), CONST_16); + uint32_t rqdL0CSize = mBase * nBase * sizeof(float); + if (rqdL0CSize < PlatformInfo::Instance().l0cSize && + (mBase + nBase) * CONST_256 * sizeof(uint16_t) < L1AB_PINGPONG_BUFFER_LEN) { + opShape.n0 = nBase; + nLoop = CeilDiv(opShape.n, opShape.n0); + coreLoop = opShape.batchSize * nLoop; + } + } + blockDim = std::min(coreLoop, coreNum); +} + +// transA transB quantMode [dtype] format +void PpMatmulTilingData::SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK) +{ + if (mmInfo.mmType == MmType::MATMUL_ACCUM_ATOMIC || mmInfo.mmType == MmType::MATMUL_WITH_BIAS || + mmInfo.mmType == MmType::MATMUL_EIN_SUM || mmInfo.mmType == MmType::MATMUL_DEQUANT || IsI8Bf16Kernel(mmInfo)) { + // SwizzleDir[1] TransA[1] TransB[1] DtypeA[3] DtypeB[3] DtypeC[3] FormatA[1] FormatB[1] FormatC[1] WithBias[1] + tilingKey = swizzleDirect; + tilingKey = (tilingKey << 1) + static_cast(mmInfo.transA); + tilingKey = (tilingKey << 1) + static_cast(mmInfo.transB); + tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeA); // 3bit for dtypeA. + tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeB); // 3bit for dtypeB. + tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeC); // 3bit for dtypeC. + tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatA); + tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatB); + tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatC); + tilingKey = (tilingKey << 1) + static_cast(mmInfo.biasFlag); + } else { + tilingKey = swizzleDirect; + tilingKey = (tilingKey << 1) + static_cast(mmInfo.transA); + tilingKey = (tilingKey << 1) + static_cast(mmInfo.transB); + tilingKey = (tilingKey << 1) + static_cast(mmInfo.isInt8); + tilingKey = (tilingKey << 1) + static_cast(mmInfo.biasFlag); + tilingKey = (tilingKey << 1) + enSplitK; + } +} + +uint32_t PpMatmulTilingData::End(const MatMulInfo &mmInfo) +{ + uint32_t cubeBlockSize = mmInfo.isInt8 ? CUBE_BLOCK_SIZE_INT8 : CUBE_BLOCK_SIZE; + uint32_t kBlockSize = mmInfo.isInt8 ? BLOCK_SIZE_INT8_K : BLOCK_SIZE; + uint32_t scaleBlockSize = mmInfo.isInt8 ? 
L1_DESCALE_BUFFER_LEN_MAX : 0; + uint32_t shapeSum = opShape.m0 + opShape.n0; + if (mmInfo.isInt8 && (mmInfo.transA || !mmInfo.transB)) { + shapeSum = RoundUp(opShape.m0, CONST_32) + RoundUp(opShape.n0, CONST_32); + } + uint32_t k0Max = shapeSum == 0 + ? L1AB_PINGPONG_BUFFER_LEN + : static_cast(static_cast(L1AB_PINGPONG_BUFFER_LEN - scaleBlockSize) / + (shapeSum * mmInfo.inDtype)); + if (mmInfo.mmType == MatMul::MatMulType::MATMUL_WITH_BIAS) { + uint32_t l1AbSize = L1AB_PINGPONG_BUFFER_LEN - opShape.n0 * sizeof(float); + k0Max = l1AbSize / (shapeSum * mmInfo.inDtype); + } + + opShape.k0 = + k0Max < cubeBlockSize ? RoundDown(k0Max, kBlockSize) : RoundDown(k0Max, cubeBlockSize); + if (opShape.k0 > CONST_512) { + opShape.k0 = RoundDown(opShape.k0, CONST_512); + } + kLoop = CeilDiv(opShape.k, opShape.k0); + return blockDim; +} + +void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim, + PpMatmulTilingData &tilingData) +{ + OpShape opShape; + opShape.batchSize = mmInfo.batchSize; + opShape.m = mmInfo.m; + opShape.n = mmInfo.n; + opShape.k = mmInfo.k; + tilingData.opShape = opShape; + tilingData.quantMode = static_cast(mmInfo.quantMode); + tilingData.SetTilingKey(mmInfo, 0, 0); // init tilingkey with transA transB. + if (opShape.m < opShape.n) { + TilingFunc(opShape, tilingData, hwInfo, mmInfo); + } else { + TilingFunc(opShape, tilingData, hwInfo, mmInfo); + } + uint32_t direct = Swizzl(tilingData); + blockDim = tilingData.End(mmInfo); + tilingData.SetTilingKey(mmInfo, direct, 0); +} +} // namespace pp_matmul diff --git a/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h b/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h new file mode 100644 index 0000000000..5c86ff59fe --- /dev/null +++ b/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h @@ -0,0 +1,90 @@ +#ifndef PP_MATMUL_TILING_DATA +#define PP_MATMUL_TILING_DATA +#include + +namespace pp_matmul { +struct MatMul { + enum class MatMulType : uint32_t { + MATMUL_DEFAULT = 0, // C = op(A) * op(B) + MATMUL_DEQUANT, // + MATMUL_ACCUM_ATOMIC, // C += op(A) * op(B) + MATMUL_WITH_BIAS, // C = op(A) * op(B) + Bias, where Bias is a vector. 
+ MATMUL_EIN_SUM + }; + enum class QuantMode : uint32_t { PER_CHANNEL_SYMM = 0, PER_CHANNEL_ASYMM, PER_TOKEN_SYMM }; +}; + +enum class TensorDType : uint32_t { TENSOR_DTYPE_FLOAT16 = 0, TENSOR_DTYPE_BF16 }; + +enum class TensorFormat : uint32_t { TENSOR_FORMAT_ND = 0, TENSOR_FORMAT_NZ }; + +struct MatMulInfo { + uint32_t batchSize{0}; + uint32_t m{0}; // actual input m + uint32_t k{0}; // actual input k + uint32_t n{0}; // actual input n + TensorDType dtypeA{TensorDType::TENSOR_DTYPE_FLOAT16}; + TensorDType dtypeB{TensorDType::TENSOR_DTYPE_FLOAT16}; + TensorDType dtypeC{TensorDType::TENSOR_DTYPE_FLOAT16}; + TensorFormat formatA{TensorFormat::TENSOR_FORMAT_ND}; + TensorFormat formatB{TensorFormat::TENSOR_FORMAT_ND}; + TensorFormat formatC{TensorFormat::TENSOR_FORMAT_ND}; + MatMul::MatMulType mmType{MatMul::MatMulType::MATMUL_DEFAULT}; + bool transA{0}; // false: 0, true: 1 + bool transB{0}; // false: 0, true: 1 + bool biasFlag{0}; // false: 0, true: 1 + bool isInt8{0}; // 是否是 int8融合 + float inDtype{0}; + float outDtype{0}; + MatMul::QuantMode quantMode{MatMul::QuantMode::PER_CHANNEL_SYMM}; +}; + +struct OpShape { + uint32_t batchSize{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; +}; + +struct HardwareInfo { + uint32_t coreNum{0}; + uint32_t l2Size{0}; + uint32_t l1Size{0}; + uint32_t l0aSize{0}; + uint32_t l0bSize{0}; + uint32_t l0cSize{0}; + uint32_t hbmBandWidth{0}; + uint32_t l2BandWidth{0}; + + HardwareInfo(); +}; + +#pragma pack(push, 1) +struct PpMatmulTilingData { + OpShape opShape{}; + uint32_t mLoop{1}; + uint32_t kLoop{1}; + uint32_t nLoop{1}; + uint32_t coreLoop{1}; + uint32_t swizzlCount{1}; + uint32_t tilingKey{0}; + uint32_t blockDim{1}; + uint32_t swizzlDirect{0}; + uint32_t splitk{0}; + uint32_t enShuffleK{0}; + uint32_t quantMode{0}; + + void SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n); + void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo); + void SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK); + uint32_t End(const MatMulInfo &mmInfo); +}; +#pragma pack(pop) + +void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim, + PpMatmulTilingData &tilingData); +} // namespace pp_matmul +#endif diff --git a/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp b/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp new file mode 100644 index 0000000000..81d987bae6 --- /dev/null +++ b/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp @@ -0,0 +1,824 @@ +// Adapted from +// https://gitee.com/ascend/ascend-transformer-boost +// +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +// This file is a part of the CANN Open Software. +// Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+// + +#define __aicore__ [aicore] +#include "kernel_operator.h" +#include "../op_host/tiling/tiling_data.h" +#include "../../mla_preprocess/op_kernel/kernel/common.h" +#include "../../mla_preprocess/op_kernel/kernel/hardware.h" +#include "../../mla_preprocess/op_kernel/kernel/mma.h" +#include "../../mla_preprocess/op_kernel/kernel/utils.h" +#include "../../mla_preprocess/op_kernel/kernel/iterator.h" +#include "../../kernels/math_utils.h" + +constexpr uint32_t L0_PINGPONG_BUFFER_LEN = 16384; +constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 131072; +constexpr uint32_t CONST_16 = 16; +constexpr uint32_t CONST_256 = 256; +constexpr uint64_t ND2NZ_STRIDE_LIMIT = 65536; +constexpr uint64_t BLOCK_SIZE_16 = 16; +constexpr uint64_t CONST_16UL = 16; +constexpr uint64_t CONST_256UL = 256; + +struct MatCoord { + uint64_t m{0}; + uint64_t k{0}; + uint64_t n{0}; +}; + +using namespace device_utils; + +template +class PpMatmulEinSum +{ + using LocalTensor = AscendC::LocalTensor; + template + using CopyGmToCbuf = gm_to_l1; + using LoadCbufToCa = l1_to_l0_a; + using LoadCbufToCb = l1_to_l0_b; + using Mad = mmad; + using CopyCcToGm = l0c_to_gm; + +public: + __aicore__ explicit PpMatmulEinSum(){}; + + __aicore__ __force_inline__ void Init(__gm__ uint8_t *__restrict__ a, __gm__ uint8_t *__restrict__ b, + __gm__ uint8_t *__restrict__ c, __gm__ uint8_t *__restrict__ tiling_data) + { + gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(a)); + gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(b)); + gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(c)); + auto gm_tiling_data = reinterpret_cast<__gm__ pp_matmul::PpMatmulTilingData *>(tiling_data); + + batch_size = gm_tiling_data->opShape.batchSize; + m = gm_tiling_data->opShape.m; + k = gm_tiling_data->opShape.k; + n = gm_tiling_data->opShape.n; + m0 = gm_tiling_data->opShape.m0; + k0 = gm_tiling_data->opShape.k0; + n0 = gm_tiling_data->opShape.n0; + tdim.m = gm_tiling_data->mLoop; + tdim.k = gm_tiling_data->kLoop; + tdim.n = gm_tiling_data->nLoop; + core_loop = gm_tiling_data->coreLoop; + swizzle_cnt = gm_tiling_data->swizzlCount; + en_shuffle_k = gm_tiling_data->enShuffleK; + + AsdopsBuffer buf; + l1_base_a = buf.template GetBuffer(0); + l1_base_b = buf.template GetBuffer( + RoundUp(m0 * k0 * sizeof(InDtype), CONST_256UL)); + l0a_base = buf.template GetBuffer(0); + l0b_base = buf.template GetBuffer(0); + num_core = AscendC::GetBlockNum(); + core_idx = AscendC::GetBlockIdx(); + ping_flag = 1; + } + + __aicore__ __force_inline__ void GetBlockIdx(uint64_t index, MatCoord &tidx) + { + uint64_t in_batch_idx = index % (tdim.m * tdim.n); + if constexpr (SwizzleDirect == 0) { // Zn + uint64_t tile_block_loop = (tdim.m + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.n); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.n); + + uint64_t n_row = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_row = tdim.m - swizzle_cnt * tile_block_idx; + } + tidx.m = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; + tidx.n = in_tile_block_idx / n_row; + if (tile_block_idx % 2 != 0) { + tidx.n = tdim.n - tidx.n - 1; + } + } else if constexpr (SwizzleDirect == 1) { // Nz + uint64_t tile_block_loop = (tdim.n + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.m); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.m); + + uint64_t n_col = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_col = tdim.n - swizzle_cnt * 
tile_block_idx; + } + tidx.m = in_tile_block_idx / n_col; + tidx.n = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; + if (tile_block_idx % 2 != 0) { + tidx.m = tdim.m - tidx.m - 1; + } + } + } + + __aicore__ __force_inline__ void Process() + { + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + + for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += num_core) { + uint64_t batch_idx = loop_idx / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBlockIdx(loop_idx, tidx); + uint64_t offset_a = 0, offset_b = 0, offset_a_next = 0, offset_b_next = 0; + uint64_t offset_c = tidx.m * m0 * batch_size * n + batch_idx * n + tidx.n * n0; + uint64_t m_actual = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t m_round = RoundUp(m_actual); + uint64_t n_round = RoundUp(n_actual); + uint64_t mn_max = m_round > n_round ? m_round : n_round; + uint64_t k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / CONST_16 * CONST_16; + uint64_t shuffle_k = en_shuffle_k ? (core_idx % tdim.k) : 0; + if (TA) { + offset_a = shuffle_k * k0 * m * batch_size + batch_idx * m + tidx.m * m0; + } else { + offset_a = tidx.m * m0 * batch_size * k + batch_idx * k + shuffle_k * k0; + } + + if (TB) { + if constexpr (FormatB != DataFormat::NZ) { + offset_b = batch_idx * k * n + tidx.n * n0 * k + shuffle_k * k0; + } else { + offset_b = batch_idx * RoundUp(k) * RoundUp(n) + + tidx.n * n0 * BLOCK_SIZE_16 + shuffle_k * k0 * RoundUp(n); + } + } else { + if constexpr (FormatB != DataFormat::NZ) { + offset_b = batch_idx * k * n + shuffle_k * k0 * n + tidx.n * n0; + } else { + offset_b = batch_idx * RoundUp(k) * RoundUp(n) + + shuffle_k * k0 * BLOCK_SIZE_16 + tidx.n * n0 * RoundUp(k); + } + } + + uint64_t k_actual = (shuffle_k == tdim.k - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l0a_buf = ping_flag ? l0a_base : l0a_base[L0_PINGPONG_BUFFER_LEN]; + LocalTensor l0b_buf = ping_flag ? l0b_base : l0b_base[L0_PINGPONG_BUFFER_LEN]; + event_t event_id = ping_flag ? 
EVENT_ID0 : EVENT_ID1; + + if (loop_idx == core_idx) { + WAIT_FLAG(MTE1, MTE2, event_id); + // *** load matrix A to L1 + if ((m == 1) || (m_actual == 1 && !TA)) { + CopyGmToCbuf(l1_buf_a, // dst + gm_a[offset_a], // src + 1, // nTileActual + 16, // nTileCeil + 1, // nVal + k_actual, // kTileActual + k_round, // kTileCeil + k); // dVal + } else { + if (TA) { + CopyGmToCbuf(l1_buf_a, // dst + gm_a[offset_a], // src + k_actual, // nTileActual + k_round, // nTileCeil + k, // nVal + m_actual, // dTileActual + m_round, // dTileCeil + m * batch_size); // dVal + } else { + CopyGmToCbuf(l1_buf_a, // dst + gm_a[offset_a], // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k * batch_size); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id); + // *** load matrix B to L1 + wait_flag(PIPE_MTE1, PIPE_MTE2, event_id + 2); + if constexpr (FormatB != DataFormat::NZ) { + if (TB) { + CopyGmToCbuf(l1_buf_b, // dst + gm_b[offset_b], // src + n_actual, // nTileActual + n_round, // nTileCeil + n, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(l1_buf_b, // dst + gm_b[offset_b], // src + k_actual, // nTileActual + k_round, // nTileCeil + k, // nVal + n_actual, // dTileActual + n_round, // dTileCeil + n); // dVal + } + } else { + if (TB) { + CopyGmToCbuf(l1_buf_b, // dst + gm_b[offset_b], // src + n_actual, // nTileActual + n_round, // nTileCeil + RoundUp(n), // nVal + k_actual, // dTileActual + k_round, // dTileCeil + RoundUp(k)); // dVal + } else { + CopyGmToCbuf(l1_buf_b, // dst + gm_b[offset_b], // src + k_actual, // nTileActual + k_round, // nTileCeil + RoundUp(k), // nVal + n_actual, // dTileActual + n_round, // dTileCeil + RoundUp(n)); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id + 2); + } + + for (tidx.k = 0; tidx.k < tdim.k; ++tidx.k) { + shuffle_k = en_shuffle_k ? (tidx.k + core_idx) % tdim.k : tidx.k; + uint64_t k_actual = (shuffle_k == (tdim.k - 1)) ? (k - shuffle_k * k0) : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + fdim.k = (k_actual + k_part_len - 1) / k_part_len; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (tidx.k < tdim.k - 1) { + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + tidx.k + 1) % tdim.k : (tidx.k + 1); + if (TA) { + offset_a_next = shuffle_k_next * k0 * m * batch_size + batch_idx * m + tidx.m * m0; + } else { + offset_a_next = tidx.m * m0 * batch_size * k + batch_idx * k + shuffle_k_next * k0; + } + + if (TB) { + if constexpr (FormatB != DataFormat::NZ) { + offset_b_next = batch_idx * k * n + tidx.n * n0 * k + shuffle_k_next * k0; + } else { + offset_b_next = + batch_idx * RoundUp(k) * RoundUp(n) + + tidx.n * n0 * BLOCK_SIZE_16 + shuffle_k_next * k0 * RoundUp(n); + } + } else { + if constexpr (FormatB != DataFormat::NZ) { + offset_b_next = batch_idx * k * n + shuffle_k_next * k0 * n + tidx.n * n0; + } else { + offset_b_next = + batch_idx * RoundUp(k) * RoundUp(n) + + shuffle_k_next * k0 * BLOCK_SIZE_16 + tidx.n * n0 * RoundUp(k); + } + } + + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? 
l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + WAIT_FLAG(MTE1, MTE2, event_id_next); + // *** load matrix A to L1 + if ((m == 1) || (m_actual == 1 && !TA)) { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual_next, // kTileActual + k_round_next, // kTileCeil + k); // dVal + } else { + if (TA) { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + k, // nVal + m_actual, // dTileActual + m_round, // dTileCeil + m * batch_size); // dVal + } else { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + k * batch_size); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id_next); + + // *** load matrix B to L1 + wait_flag(PIPE_MTE1, PIPE_MTE2, event_id_next + 2); + if constexpr (FormatB != DataFormat::NZ) { + if (TB) { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + n_actual, // nTileActual + n_round, // nTileCeil + n, // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + k, // nVal + n_actual, // dTileActual + n_round, // dTileCeil + n); // dVal + } + } else { + if (TB) { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + n_actual, // nTileActual + n_round, // nTileCeil + RoundUp(n), // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + RoundUp(k)); // dVal + } else { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + RoundUp(k), // nVal + n_actual, // dTileActual + n_round, // dTileCeil + RoundUp(n)); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + if (tidx.k == tdim.k - 1 && loop_idx + num_core < core_loop) { + uint64_t b_idx_next = (loop_idx + num_core) / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBlockIdx(loop_idx + num_core, tidx); + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx % tdim.k) : 0; + uint64_t m_actual_next = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual_next = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t m_round_next = (m_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t n_round_next = (n_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? 
(k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + if (TA) { + offset_a_next = shuffle_k_next * k0 * m * batch_size + b_idx_next * m + tidx.m * m0; + } else { + offset_a_next = tidx.m * m0 * batch_size * k + b_idx_next * k + shuffle_k_next * k0; + } + + if (TB) { + if constexpr (FormatB != DataFormat::NZ) { + offset_b_next = b_idx_next * k * n + tidx.n * n0 * k + shuffle_k_next * k0; + } else { + offset_b_next = + b_idx_next * RoundUp(k) * RoundUp(n) + + tidx.n * n0 * BLOCK_SIZE_16 + shuffle_k_next * k0 * RoundUp(n); + } + } else { + if constexpr (FormatB != DataFormat::NZ) { + offset_b_next = b_idx_next * k * n + shuffle_k_next * k0 * n + tidx.n * n0; + } else { + offset_b_next = + b_idx_next * RoundUp(k) * RoundUp(n) + + shuffle_k_next * k0 * BLOCK_SIZE_16 + tidx.n * n0 * RoundUp(k); + } + } + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + WAIT_FLAG(MTE1, MTE2, event_id_next); + // *** load matrix A to L1 + if (m == 1 || m_actual_next == 1 && !TA) { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + m_actual_next, // nTileActual + m_round_next, // nTileCeil + m, // nVal + k_actual_next, // kTileActual + k_round_next, // kTileCeil + k); // dVal + } else { + if (TA) { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + k, // nVal + m_actual_next, // dTileActual + m_round_next, // dTileCeil + m * batch_size); // dVal + } else { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + m_actual_next, // nTileActual + m_round_next, // nTileCeil + m, // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + k * batch_size); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id_next); + + // *** load matrix B to L1 + wait_flag(PIPE_MTE1, PIPE_MTE2, event_id_next + 2); + if constexpr (FormatB != DataFormat::NZ) { + if (TB) { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + n_actual_next, // nTileActual + n_round_next, // nTileCeil + n, // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + k, // nVal + n_actual_next, // dTileActual + n_round_next, // dTileCeil + n); // dVal + } + } else { + if (TB) { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + n_actual_next, // nTileActual + n_round_next, // nTileCeil + RoundUp(n), // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + RoundUp(k)); // dVal + } else { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + RoundUp(k), // nVal + n_actual_next, // dTileActual + n_round_next, // dTileCeil + RoundUp(n)); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + MatCoord fidx{0}; + for (fidx.k = 0; fidx.k < fdim.k; ++fidx.k) { + uint32_t k0_round = (fidx.k < fdim.k - 1) ? k_part_len : k_round - fidx.k * k_part_len; + uint32_t k0_actual = (fidx.k < fdim.k - 1) ? k_part_len : k_actual - fidx.k * k_part_len; + + auto mte1_mad_ping_flag = 1 - fidx.k % 2; + auto mte1_mad_event_id = mte1_mad_ping_flag ? 
EVENT_ID0 : EVENT_ID1; + auto l0a_buf = l0a_base[(fidx.k % 2) * L0_PINGPONG_BUFFER_LEN]; + auto l0b_buf = l0b_base[(fidx.k % 2) * L0_PINGPONG_BUFFER_LEN]; + + // *** load matrix A from L1 to L0A + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id); + } + WAIT_FLAG(M, MTE1, mte1_mad_event_id); + if ((m == 1) || (m_actual == 1 && !TA)) { + l1_to_l0_a( + l0a_buf, // dst + l1_buf_a[fidx.k * k_part_len], // src + 0, // mTileCeil + CeilDiv(k0_round), // kPartCeil + 0, // mSrcStride + 1, // kSrcStride + 0, // mDstStride + 0); // kDstStride + } else { + if (TA) { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[fidx.k * k_part_len * CONST_16], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + k_round / CONST_16, // mSrcStride + 1, // kSrcStride + k0_round / CONST_16, // mDstStride + 1); // kDstStride + } else { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[fidx.k * k_part_len * m_round], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + 1, // mSrcStride + m_round / CONST_16, // kSrcStride + k0_round / CONST_16, // mDstStride + 1); // kDstStride + } + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id); + } + + // *** load matrix B from L1 to L0B + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id + 2); + } + if (TB) { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * n_round], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + 1, // nSrcStride + n_round / CONST_16, // kSrcStride + 1, // nDstStride + k0_round / CONST_16); // kDstStride + } else { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * CONST_16], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + k_round / CONST_16, // nSrcStride + 1, // kSrcStride + 1, // nDstStride + n_round / CONST_16); // kDstStride + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id + 2); + } + + SET_FLAG(MTE1, M, mte1_mad_event_id); + WAIT_FLAG(MTE1, M, mte1_mad_event_id); + + bool init_c = (tidx.k == 0 && fidx.k == 0); + if (init_c) { + WAIT_FLAG(FIX, M, EVENT_ID0); + } + + if (m != 1 && m_actual == 1 && TA) { + Mad(l0c_buf, // c + l0a_buf, // a + l0b_buf, // b + CONST_16, // mTileActual + n_actual, // nTileActual + k0_actual, // kTileActual + init_c); // initC + } else { + Mad(l0c_buf, // c + l0a_buf, // a + l0b_buf, // b + m_actual, // mTileActual + n_actual, // nTileActual + k0_actual, // kTileActual + init_c); // initC + } + + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, mte1_mad_event_id); + } + + ping_flag = 1 - ping_flag; + } + + SET_FLAG(M, FIX, EVENT_ID0); + WAIT_FLAG(M, FIX, EVENT_ID0); + + // copy from L0C to gm + CopyCcToGm(gm_c[offset_c], // dst + l0c_buf, // src + m_actual, // mTileActual + n_actual, // nTileActual + m_round, // mTileCeil + n * batch_size); // nActual + SET_FLAG(FIX, M, EVENT_ID0); + } + + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(FIX, M, EVENT_ID0); + PIPE_BARRIER(ALL); + } + +private: + AscendC::GlobalTensor gm_a; + AscendC::GlobalTensor gm_b; + AscendC::GlobalTensor gm_c; + AscendC::LocalTensor l1_base_a; + AscendC::LocalTensor l1_base_b; + AscendC::LocalTensor l0a_base; + AscendC::LocalTensor l0b_base; + AscendC::LocalTensor l0c_buf; + + uint32_t num_core{0}; + uint32_t batch_size{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + MatCoord tdim{0}; + MatCoord fdim{0}; + uint32_t 
core_loop{0}; + uint32_t swizzle_cnt{1}; + uint32_t core_idx{0}; + uint32_t en_shuffle_k{0}; + uint32_t ping_flag{0}; +}; + +extern "C" __global__ __aicore__ void batch_matmul_transpose(GM_ADDR gm_a, GM_ADDR gm_b, GM_ADDR gm_c, + GM_ADDR gm_tiling_data) +{ + PpMatmulEinSum<0, false, false, half, half, DataFormat::ND> + einsum_0_n_fp16_nd; // swizzleDir[0] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<1, false, false, half, half, DataFormat::ND> + einsum_1_n_fp16_nd; // swizzleDir[1] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<0, false, true, half, half, DataFormat::ND> + einsum_0_t_fp16_nd; // swizzleDir[0] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<1, false, true, half, half, DataFormat::ND> + einsum_1_t_fp16_nd; // swizzleDir[1] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<0, false, false, __bf16, __bf16, DataFormat::ND> + einsum_0_n_bf16_nd; // swizzleDir[0] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<1, false, false, __bf16, __bf16, DataFormat::ND> + einsum_1_n_bf16_nd; // swizzleDir[1] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<0, false, true, __bf16, __bf16, DataFormat::ND> + einsum_0_t_bf16_nd; // swizzleDir[0] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<1, false, true, __bf16, __bf16, DataFormat::ND> + einsum_1_t_bf16_nd; // swizzleDir[1] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[0] + + PpMatmulEinSum<0, false, false, half, half, DataFormat::NZ> + einsum_0_n_fp16_nz; // swizzleDir[0] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<1, false, false, half, half, DataFormat::NZ> + einsum_1_n_fp16_nz; // swizzleDir[1] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<0, false, true, half, half, DataFormat::NZ> + einsum_0_t_fp16_nz; // swizzleDir[0] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<1, false, true, half, half, DataFormat::NZ> + einsum_1_t_fp16_nz; // swizzleDir[1] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<0, false, false, __bf16, __bf16, DataFormat::NZ> + einsum_0_n_bf16_nz; // swizzleDir[0] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<1, false, false, __bf16, __bf16, DataFormat::NZ> + einsum_1_n_bf16_nz; // swizzleDir[1] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<0, false, true, __bf16, __bf16, DataFormat::NZ> + einsum_0_t_bf16_nz; // swizzleDir[0] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<1, false, true, __bf16, __bf16, DataFormat::NZ> + einsum_1_t_bf16_nz; // swizzleDir[1] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[1] + + SetPadding((uint64_t)0); + SetNdpara(1, 0, 0); + SetAtomicnone(); + + // get tiling args + auto tiling_data = reinterpret_cast<__gm__ pp_matmul::PpMatmulTilingData *>(gm_tiling_data); + uint32_t 
masked_key = tiling_data->tilingKey >> 2; + + switch (masked_key) { + case 0b00000100100100: + case 0b01000100100100: + einsum_0_n_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_n_fp16_nd.Process(); + break; + case 0b00100100100100: + case 0b01100100100100: + einsum_0_t_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_t_fp16_nd.Process(); + break; + case 0b10000100100100: + case 0b11000100100100: + einsum_1_n_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_n_fp16_nd.Process(); + break; + case 0b10100100100100: + case 0b11100100100100: + einsum_1_t_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_t_fp16_nd.Process(); + break; + case 0b00001001001000: + case 0b01001001001000: + einsum_0_n_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_n_bf16_nd.Process(); + break; + case 0b00101001001000: + case 0b01101001001000: + einsum_0_t_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_t_bf16_nd.Process(); + break; + case 0b10001001001000: + case 0b11001001001000: + einsum_1_n_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_n_bf16_nd.Process(); + break; + case 0b10101001001000: + case 0b11101001001000: + einsum_1_t_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_t_bf16_nd.Process(); + break; + + case 0b00000100100101: + case 0b01000100100101: + einsum_0_n_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_n_fp16_nz.Process(); + break; + case 0b00100100100101: + case 0b01100100100101: + einsum_0_t_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_t_fp16_nz.Process(); + break; + case 0b10000100100101: + case 0b11000100100101: + einsum_1_n_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_n_fp16_nz.Process(); + break; + case 0b10100100100101: + case 0b11100100100101: + einsum_1_t_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_t_fp16_nz.Process(); + break; + case 0b00001001001001: + case 0b01001001001001: + einsum_0_n_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_n_bf16_nz.Process(); + break; + case 0b00101001001001: + case 0b01101001001001: + einsum_0_t_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_t_bf16_nz.Process(); + break; + case 0b10001001001001: + case 0b11001001001001: + einsum_1_n_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_n_bf16_nz.Process(); + break; + case 0b10101001001001: + case 0b11101001001001: + einsum_1_t_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_t_bf16_nz.Process(); + break; + default: + break; + } +} + + +namespace vllm_ascend { + +extern void batch_matmul_transpose_impl( + void* stream, + void* gm_a, + void* gm_b, + void* gm_c, + void* gm_tiling_data, + const uint32_t block_dim) +{ + batch_matmul_transpose<<>>( + gm_a, + gm_b, + gm_c, + gm_tiling_data); +} + +} \ No newline at end of file diff --git a/csrc/kernels/math_utils.h b/csrc/kernels/math_utils.h new file mode 100644 index 0000000000..62b46921c1 --- /dev/null +++ b/csrc/kernels/math_utils.h @@ -0,0 +1,15 @@ +#ifndef KERNEL_MATH_UTILS_H +#define KERNEL_MATH_UTILS_H +#include + +namespace device_utils { + +template +__aicore__ __force_inline__ T RoundUp(const T &val) +{ + return (val + roundVal - 1) / roundVal * roundVal; +} + +}; // namespace device_utils + +#endif diff --git a/csrc/ops.h b/csrc/ops.h index c249bb5875..2401792db0 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -158,4 +158,13 @@ namespace vllm_ascend { void* tiling, const uint32_t block_dim ); + + extern void batch_matmul_transpose_impl( + void* stream, + void* gm_a, + void* gm_b, + 
void* gm_c,
+        void* gm_tiling_data,
+        const uint32_t block_dim
+    );
 }
diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index 9eaba72363..7cae5fc17d 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -24,6 +24,7 @@
 #include "ops.h"
 #include "utils.h"
 #include "mla_preprocess/op_host/mla_preprocess.h"
+#include "batch_matmul_transpose/op_host/batch_matmul_transpose.h"
 
 namespace vllm_ascend {
 
@@ -458,6 +459,39 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
     cmd.Run();
     return y_out;
 }
+
+void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
+                            c10::optional<c10::string_view> format_mode,
+                            c10::optional<c10::string_view> quant_mode)
+{
+    auto [tiling_tensor, block_dim] = bmm_trans::batch_matmul_transpose_tiling(
+        tensor_a,
+        tensor_b,
+        tensor_c,
+        format_mode,
+        quant_mode
+    );
+
+    void *gm_a = tensor_a.data_ptr();
+    void *gm_b = tensor_b.data_ptr();
+    void *gm_c = tensor_c.data_ptr();
+    void *gm_tiling_data = tiling_tensor.data_ptr();
+
+    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
+    at_npu::native::OpCommand cmd;
+    cmd.Name("batch_matmul_transpose");
+
+    cmd.SetCustomHandler([stream, gm_a, gm_b, gm_c, gm_tiling_data,
+                          block_dim]() -> int {
+        batch_matmul_transpose_impl(stream, gm_a, gm_b, gm_c, gm_tiling_data,
+                                    block_dim);
+        return 0;
+    });
+    cmd.Run();
+    return;
+
+}
+
 } // namespace vllm_ascend
 
 TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -511,4 +545,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
         " Tensor q_out1, Tensor kv_cache_out1)"
     );
     ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
+
+    ops.def(
+        "batch_matmul_transpose(Tensor tensor_a, Tensor tensor_b, Tensor tensor_c, str? format_mode=None, str? quant_mode=None) -> ()");
+    ops.impl("batch_matmul_transpose", torch::kPrivateUse1, &vllm_ascend::batch_matmul_transpose);
 }
diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp
index dbb056be89..eceb5aafae 100644
--- a/csrc/torch_binding_meta.cpp
+++ b/csrc/torch_binding_meta.cpp
@@ -114,6 +114,13 @@ std::tuple mla_preproces
     return {q_out0, kv_cache_out0, q_out1, kv_cache_out1};
 }
 
+void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
+                            c10::optional<c10::string_view> format_mode,
+                            c10::optional<c10::string_view> quant_mode)
+{
+    return;
+
+}
 } // namespace meta
 } // namespace vllm_ascend
 
@@ -132,5 +139,7 @@ namespace {
         ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
         // MLA preprocess
         ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess);
+        // batch_matmul_transpose
+        ops.impl("batch_matmul_transpose", &vllm_ascend::meta::batch_matmul_transpose);
     }
 }
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 177d91bc8a..1ee4e64055 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -563,21 +563,13 @@ def __init__(
         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
 
     def _v_up_proj(self, x):
-        if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
-            x = x.view(-1, self.num_heads, self.kv_lora_rank)
-            x = torch_npu.npu_transpose_batchmatmul(x,
-                                                    self.W_UV,
-                                                    perm_x1=[1, 0, 2],
-                                                    perm_x2=[0, 1, 2],
-                                                    perm_y=[1, 0, 2])
-            x = x.reshape(-1, self.num_heads * self.v_head_dim)
-        else:
-            # Convert from (B, N, L) to (N, B, L)
-            x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-            x = torch.bmm(x, self.W_UV)
-            # Convert from (N, B, V) to (B, N * V)
-            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        x = x.view(-1, self.num_heads, self.kv_lora_rank)
+        b, _, _ = x.shape
+        res2 = torch.empty((b, self.num_heads, self.v_head_dim),
+                           dtype=x.dtype,
+                           device=x.device)
+        torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res2)
+        x = res2.reshape(-1, self.num_heads * self.v_head_dim)
         return x
 
     # Return `ql_nope`, `q_pe`
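
Reviewer note (not part of the patch): a quick end-to-end sanity check of the new op against a torch.einsum reference. This is a sketch under assumptions: the extension is built, an NPU device is visible, and importing vllm_ascend is enough to register torch.ops._C_ascend (adjust that import if the loader lives elsewhere). Shapes follow the ND tiling checks in batch_matmul_transpose.h: A is (m, batch, k), B is (batch, k, n), C is (m, batch, n), dtype fp16 or bf16.

import torch
import torch_npu   # noqa: F401  registers the NPU device
import vllm_ascend # noqa: F401  assumed to load the _C_ascend extension

m, batch, k, n = 32, 16, 512, 128
a = torch.randn(m, batch, k, dtype=torch.float16).npu()
b = torch.randn(batch, k, n, dtype=torch.float16).npu()
c = torch.empty(m, batch, n, dtype=torch.float16).npu()

# C[m, b, n] = sum_k A[m, b, k] * B[b, k, n], i.e. einsum("mbk,bkn->mbn"),
# computed without materializing the (batch, m, k) transpose that the old
# torch.bmm fallback in _v_up_proj needed.
torch.ops._C_ascend.batch_matmul_transpose(a, b, c)

ref = torch.einsum("mbk,bkn->mbn", a.cpu().float(), b.cpu().float())
torch.testing.assert_close(c.cpu().float(), ref, rtol=4e-3, atol=4e-3)  # fp16 tolerance, assumed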