| 
 | 1 | +// Licensed under the BSD 3-Clause License  (the "License");  | 
 | 2 | +// you may not use this file except in compliance with the License.  | 
 | 3 | +// You may obtain a copy of the License at  | 
 | 4 | +//  | 
 | 5 | +// Unless required by applicable law or agreed to in writing, software  | 
 | 6 | +// distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 7 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
 | 8 | +// See the License for the specific language governing permissions and  | 
 | 9 | +// limitations under the License.  | 
 | 10 | +#include "defines.h"  | 
 | 11 | +#include "build_tree_tiling.h"  | 
 | 12 | +#include "tiling/platform/platform_ascendc.h"  | 
 | 13 | +#include "aclrtlaunch_build_tree_efficient.h"  | 
 | 14 | +#include "torch_helper.h"  | 
 | 15 | + | 
 | 16 | +namespace sglang {  | 
 | 17 | +namespace npu_kernel {  | 
// Alignment granularity in bytes: the host-side tiling buffer size is rounded
// up to a multiple of this value before allocation (see get_tiling).
constexpr uint32_t PADDING_BYTE = 32U;
 | 19 | + | 
 | 20 | +at::Tensor get_tiling(int32_t &block_dim, int32_t &workspace_size, int32_t batch_size, int32_t mask_size,  | 
 | 21 | +    int64_t topk, int64_t depth, int64_t draft_token_num, int64_t tree_mask_mode)  | 
 | 22 | +{  | 
 | 23 | +    auto ascendc_platform = platform_ascendc::PlatformAscendCManager::GetInstance();  | 
 | 24 | +    int32_t max_aiv_core = static_cast<int32_t>(ascendc_platform->GetCoreNumAiv());  | 
 | 25 | +    block_dim = std::min(max_aiv_core, batch_size);  | 
 | 26 | +    workspace_size = static_cast<int32_t>(ascendc_platform->GetLibApiWorkSpaceSize());  | 
 | 27 | + | 
 | 28 | +    // align to 32 bytes  | 
 | 29 | +    int32_t tiling_size = (sizeof(BuildTreeTilingData) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE;  | 
 | 30 | +    auto tiling_buffer = at::empty({tiling_size}, at::TensorOptions().dtype(at::kByte).device(at::kCPU));  | 
 | 31 | +      | 
 | 32 | +    BuildTreeTilingData *tiling_data = reinterpret_cast<BuildTreeTilingData *>(tiling_buffer.data_ptr());  | 
 | 33 | +    tiling_data->batch_size = batch_size;  | 
 | 34 | +    tiling_data->mask_size = mask_size;  | 
 | 35 | +    tiling_data->topk = topk;  | 
 | 36 | +    tiling_data->depth = depth;  | 
 | 37 | +    tiling_data->draft_token_num = draft_token_num;  | 
 | 38 | +    tiling_data->tree_mask_mode = tree_mask_mode;  | 
 | 39 | +      | 
 | 40 | +    auto num_big_core = batch_size % max_aiv_core;  | 
 | 41 | +    tiling_data->big_core_num = num_big_core == 0 ? block_dim : num_big_core;  | 
 | 42 | +    tiling_data->big_core_tile_num = (batch_size + num_big_core - 1) / num_big_core;  | 
 | 43 | +    tiling_data->small_core_tile_num = batch_size / num_big_core;  | 
 | 44 | + | 
 | 45 | +    auto tiling_tensor = TorchNpuHepler::CopyTensorHostToDevice(tiling_buffer);  | 
 | 46 | +    return tiling_tensor;  | 
 | 47 | +}  | 
 | 48 | + | 
 | 49 | +HOST_API void build_tree_efficient(const at::Tensor &parent_list,   | 
 | 50 | +    const at::Tensor &selected_index,  | 
 | 51 | +    const at::Tensor &verified_seq_len,   | 
 | 52 | +    const at::Tensor &tree_mask,   | 
 | 53 | +    const at::Tensor &positions,  | 
 | 54 | +    const at::Tensor &retrive_index,   | 
 | 55 | +    const at::Tensor &retrive_next_token,   | 
 | 56 | +    const at::Tensor &retrive_next_sibling,   | 
 | 57 | +    int64_t topk,   | 
 | 58 | +    int64_t depth,   | 
 | 59 | +    int64_t draft_token_num,   | 
 | 60 | +    int64_t tree_mask_mode)  | 
 | 61 | +{  | 
 | 62 | +    if (QLEN_ONLY_BITPACKING == tree_mask_mode) {  | 
 | 63 | +        throw std::runtime_error("Not implemented");  | 
 | 64 | +    }  | 
 | 65 | + | 
 | 66 | +    if (parent_list.options().dtype() != at::kLong   | 
 | 67 | +        || selected_index.options().dtype() != at::kLong  | 
 | 68 | +        || verified_seq_len.options().dtype() != at::kLong  | 
 | 69 | +        || tree_mask.options().dtype() != at::kBool   | 
 | 70 | +        || positions.options().dtype() != at::kLong  | 
 | 71 | +        || retrive_index.options().dtype() != at::kLong   | 
 | 72 | +        || retrive_next_token.options().dtype() != at::kLong  | 
 | 73 | +        || retrive_next_sibling.options().dtype() != at::kLong) {  | 
 | 74 | +        throw std::invalid_argument("Invaild input datetype. " \  | 
 | 75 | +            "Support combo: int64, int64, int64, bool, int64, int64, int64, int64");  | 
 | 76 | +    }  | 
 | 77 | +    int32_t block_dim;  | 
 | 78 | +    int32_t workspace_size;  | 
 | 79 | +    int32_t batch_size = parent_list.sizes()[0];  | 
 | 80 | +    int32_t mask_size = tree_mask.size(0);  | 
 | 81 | + | 
 | 82 | +    at::Tensor tiling_tensor = get_tiling(block_dim, workspace_size, batch_size, mask_size, topk, depth, draft_token_num,   | 
 | 83 | +        tree_mask_mode);  | 
 | 84 | + | 
 | 85 | +    auto workspace_tensor =   | 
 | 86 | +        at::empty({workspace_size}, at::TensorOptions().dtype(at::kByte).device(parent_list.options().device()));  | 
 | 87 | +    /* lauch the kernal function via torch */  | 
 | 88 | +    EXEC_KERNEL_CMD(build_tree_efficient, block_dim, parent_list, selected_index, verified_seq_len, tree_mask,   | 
 | 89 | +        positions, retrive_index, retrive_next_token, retrive_next_sibling, workspace_tensor, tiling_tensor);  | 
 | 90 | +}  | 
 | 91 | + | 
 | 92 | +}  | 
 | 93 | +}  | 
0 commit comments