|
19 | 19 | #include "graph/utils/type_utils.h" |
20 | 20 | #include "register/op_def_registry.h" |
21 | 21 | #include "../op_kernel/cam_moe_combine_normal_tiling.h" |
22 | | -#include "tiling_args.h" |
23 | 22 |
|
24 | 23 | using namespace AscendC; |
25 | 24 | using namespace ge; |
26 | | -using namespace Moe; |
27 | 25 |
|
28 | 26 | namespace { |
29 | 27 | class Mc2TilingUtils |
@@ -85,6 +83,7 @@ constexpr uint64_t MB_SIZE = 1024UL * 1024UL; |
85 | 83 | constexpr uint64_t TRIPLE = 3; |
86 | 84 | constexpr uint64_t WIN_ADDR_ALIGN = 512UL; |
87 | 85 | constexpr uint64_t SCALE_RECV_IDX_BUFFER = 44UL; // scale32B + 3*4 src info |
| 86 | +constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3U * 1024UL * 1024UL; |
88 | 87 | constexpr uint64_t DOUBLE_DATA_BUFFER = 2UL; |
89 | 88 | constexpr uint64_t MAX_OUT_DTYPE_SIZE = 2UL; |
90 | 89 | constexpr uint64_t UB_ALIGN = 32UL; |
@@ -515,17 +514,20 @@ static ge::graphStatus CamMoeCombineNormalA3TilingFuncImpl(gert::TilingContext * |
515 | 514 | uint64_t maxBs = static_cast<uint64_t>(tilingData->camMoeCombineNormalInfo.globalBs) / epWorldSize; |
516 | 515 | // combine数据区 token首地址对齐512 |
517 | 516 | uint64_t tokenNeedSizeCombine = ((h * MAX_OUT_DTYPE_SIZE + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; |
| 517 | + // dispatch数据区 token首对齐512,有效token长度h_align_32b + scale(32b) + 三元组(3*4b) |
| 518 | + uint64_t tokenActualLen = ((h * MAX_OUT_DTYPE_SIZE + UB_ALIGN - 1UL) / UB_ALIGN) * UB_ALIGN + SCALE_RECV_IDX_BUFFER; |
| 519 | + uint64_t tokenNeedSizeDispatch = ((tokenActualLen + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; |
518 | 520 | uint64_t actualSize = |
519 | | - (maxBs * k * tokenNeedSizeCombine + COMBINE_STATE_WIN_OFFSET + NOTIFY_DISPATCH_WIN_OFFSET) * DOUBLE_DATA_BUFFER; |
| 521 | + (maxBs * k * (tokenNeedSizeCombine + tokenNeedSizeDispatch) + COMBINE_STATE_WIN_OFFSET) * DOUBLE_DATA_BUFFER; |
520 | 522 | OP_TILING_CHECK( |
521 | 523 | (actualSize > maxWindowSize), |
522 | 524 | OP_LOGE(nodeName, |
523 | 525 | "HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu, localMoeExpertNum = %u," |
524 | | - " tokenNeedSizeCombine = %lu, k = %lu, NEEDED_HCCL_BUFFSIZE(" |
525 | | - "((maxBs * k * tokenNeedSizeCombine)) + 3MB + 204MB) * 2) = %luMB, " |
| 526 | + " tokenNeedSizeDispatch = %lu, tokenNeedSizeCombine = %lu, k = %lu, NEEDED_HCCL_BUFFSIZE(" |
| 527 | + "((maxBs * tokenNeedSizeDispatch) + (maxBs * tokenNeedSizeCombine * k) + 3MB) * 2) = %luMB, " |
526 | 528 | "HCCL_BUFFSIZE=%luMB.", |
527 | | - maxBs, h, epWorldSize, localMoeExpertNum, tokenNeedSizeCombine, k, actualSize / MB_SIZE + 1UL, |
528 | | - maxWindowSize / MB_SIZE), |
| 529 | + maxBs, h, epWorldSize, localMoeExpertNum, tokenNeedSizeDispatch, tokenNeedSizeCombine, k, |
| 530 | + actualSize / MB_SIZE + 1UL, maxWindowSize / MB_SIZE), |
529 | 531 | return ge::GRAPH_FAILED); |
530 | 532 | tilingData->camMoeCombineNormalInfo.totalWinSize = maxWindowSize; |
531 | 533 |
|
|
0 commit comments