Skip to content

Commit e44b310

Browse files
committed
[ops] build_tree
1 parent 29aec86 commit e44b310

File tree

6 files changed

+416
-0
lines changed

6 files changed

+416
-0
lines changed

csrc/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ FILE(GLOB OP_SRCS
88
${PROJECT_OP_SRC_BASE}/cache_location_assign/op_host/cache_loc_assign.cpp
99
${PROJECT_OP_SRC_BASE}/alloc_extend/op_host/alloc_extend_tiling.cpp
1010
${PROJECT_OP_SRC_BASE}/assign_cache_op/op_host/assign_cache.cpp
11+
${PROJECT_OP_SRC_BASE}/build_tree/op_host/build_tree.cpp
1112
${PROJECT_OP_SRC_BASE}/mla_preprocess/op_host/mla_preprocess.cpp
1213
)
1314

@@ -24,6 +25,7 @@ ascendc_library(no_workspace_kernel STATIC
2425
ascendc_library(workspace_kernel STATIC
2526
${PROJECT_OP_SRC_BASE}/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
2627
${PROJECT_OP_SRC_BASE}/alloc_extend/op_kernel/alloc_extend_kernel.cpp
28+
${PROJECT_OP_SRC_BASE}/build_tree/op_kernel/build_tree_kernel.cpp
2729
)
2830

2931
ascendc_compile_definitions(workspace_kernel PRIVATE
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Licensed under the BSD 3-Clause License (the "License");
2+
// you may not use this file except in compliance with the License.
3+
// You may obtain a copy of the License at
4+
//
5+
// Unless required by applicable law or agreed to in writing, software
6+
// distributed under the License is distributed on an "AS IS" BASIS,
7+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8+
// See the License for the specific language governing permissions and
9+
// limitations under the License.
10+
#include "defines.h"
11+
#include "build_tree_tiling.h"
12+
#include "tiling/platform/platform_ascendc.h"
13+
#include "aclrtlaunch_build_tree_efficient.h"
14+
#include "torch_helper.h"
15+
16+
namespace sglang {
17+
namespace npu_kernel {
18+
constexpr uint32_t PADDING_BYTE = 32U;
19+
20+
at::Tensor get_tiling(int32_t &block_dim, int32_t &workspace_size, int32_t batch_size, int32_t mask_size,
21+
int64_t topk, int64_t depth, int64_t draft_token_num, int64_t tree_mask_mode)
22+
{
23+
auto ascendc_platform = platform_ascendc::PlatformAscendCManager::GetInstance();
24+
int32_t max_aiv_core = static_cast<int32_t>(ascendc_platform->GetCoreNumAiv());
25+
block_dim = std::min(max_aiv_core, batch_size);
26+
workspace_size = static_cast<int32_t>(ascendc_platform->GetLibApiWorkSpaceSize());
27+
28+
// align to 32 bytes
29+
int32_t tiling_size = (sizeof(BuildTreeTilingData) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE;
30+
auto tiling_buffer = at::empty({tiling_size}, at::TensorOptions().dtype(at::kByte).device(at::kCPU));
31+
32+
BuildTreeTilingData *tiling_data = reinterpret_cast<BuildTreeTilingData *>(tiling_buffer.data_ptr());
33+
tiling_data->batch_size = batch_size;
34+
tiling_data->mask_size = mask_size;
35+
tiling_data->topk = topk;
36+
tiling_data->depth = depth;
37+
tiling_data->draft_token_num = draft_token_num;
38+
tiling_data->tree_mask_mode = tree_mask_mode;
39+
40+
auto num_big_core = batch_size % max_aiv_core;
41+
tiling_data->big_core_num = num_big_core == 0 ? block_dim : num_big_core;
42+
tiling_data->big_core_tile_num = (batch_size + num_big_core - 1) / num_big_core;
43+
tiling_data->small_core_tile_num = batch_size / num_big_core;
44+
45+
auto tiling_tensor = TorchNpuHepler::CopyTensorHostToDevice(tiling_buffer);
46+
return tiling_tensor;
47+
}
48+
49+
HOST_API void build_tree_efficient(const at::Tensor &parent_list,
50+
const at::Tensor &selected_index,
51+
const at::Tensor &verified_seq_len,
52+
const at::Tensor &tree_mask,
53+
const at::Tensor &positions,
54+
const at::Tensor &retrive_index,
55+
const at::Tensor &retrive_next_token,
56+
const at::Tensor &retrive_next_sibling,
57+
int64_t topk,
58+
int64_t depth,
59+
int64_t draft_token_num,
60+
int64_t tree_mask_mode)
61+
{
62+
if (QLEN_ONLY_BITPACKING == tree_mask_mode) {
63+
throw std::runtime_error("Not implemented");
64+
}
65+
66+
if (parent_list.options().dtype() != at::kLong
67+
|| selected_index.options().dtype() != at::kLong
68+
|| verified_seq_len.options().dtype() != at::kLong
69+
|| tree_mask.options().dtype() != at::kBool
70+
|| positions.options().dtype() != at::kLong
71+
|| retrive_index.options().dtype() != at::kLong
72+
|| retrive_next_token.options().dtype() != at::kLong
73+
|| retrive_next_sibling.options().dtype() != at::kLong) {
74+
throw std::invalid_argument("Invaild input datetype. " \
75+
"Support combo: int64, int64, int64, bool, int64, int64, int64, int64");
76+
}
77+
int32_t block_dim;
78+
int32_t workspace_size;
79+
int32_t batch_size = parent_list.sizes()[0];
80+
int32_t mask_size = tree_mask.size(0);
81+
82+
at::Tensor tiling_tensor = get_tiling(block_dim, workspace_size, batch_size, mask_size, topk, depth, draft_token_num,
83+
tree_mask_mode);
84+
85+
auto workspace_tensor =
86+
at::empty({workspace_size}, at::TensorOptions().dtype(at::kByte).device(parent_list.options().device()));
87+
/* lauch the kernal function via torch */
88+
EXEC_KERNEL_CMD(build_tree_efficient, block_dim, parent_list, selected_index, verified_seq_len, tree_mask,
89+
positions, retrive_index, retrive_next_token, retrive_next_sibling, workspace_tensor, tiling_tensor);
90+
}
91+
92+
}
93+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Licensed under the BSD 3-Clause License (the "License");
2+
// you may not use this file except in compliance with the License.
3+
// You may obtain a copy of the License at
4+
//
5+
// Unless required by applicable law or agreed to in writing, software
6+
// distributed under the License is distributed on an "AS IS" BASIS,
7+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8+
// See the License for the specific language governing permissions and
9+
// limitations under the License.
10+
11+
#ifndef BUILD_TREE_TILING_H
12+
#define BUILD_TREE_TILING_H
13+
14+
#include <cstdint>
15+
namespace sglang {
16+
namespace npu_kernel {
17+
18+
typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode;
19+
20+
struct BuildTreeTilingData {
21+
int64_t topk;
22+
int64_t depth;
23+
int64_t draft_token_num;
24+
int64_t tree_mask_mode;
25+
26+
int32_t batch_size;
27+
int32_t mask_size;
28+
29+
int32_t big_core_num;
30+
int32_t big_core_tile_num;
31+
int32_t small_core_tile_num;
32+
};
33+
34+
} // namespace npu_kernel
35+
} // namespace sglang
36+
37+
#endif // BUILD_TREE_TILING_H

0 commit comments

Comments
 (0)