Commit 6fdc38c

Add build tree README.md
1 parent 5e748e8 commit 6fdc38c

1 file changed

+174
-0
lines changed

csrc/build_tree/README.md

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
# torch.ops.build_tree_kernel_efficient

## Function Description

This is the AscendC version of the `build_tree_kernel_efficient` kernel function. It organizes the draft model's multi-step top-k candidate tokens into a verification tree.

Adapted from the [CUDA Implementation](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/speculative/eagle_utils.cu).

For each sample, the kernel builds the following in parallel:

- `tree_mask`: marks which nodes must be verified by the target model
- `positions`: absolute position of each node in the full sequence
- `retrive_*` linked lists: allow O(1) navigation to a node's first child and next sibling
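
To illustrate how first-child / next-sibling links are typically consumed, here is a minimal host-side sketch in plain Python. The helper `children_of`, the `NO_NODE` constant, and the example arrays are hypothetical and only mirror the roles of `retrive_next_token` and `retrive_next_sibling` for a single sample, with -1 meaning "none"; the sibling order the kernel actually produces may differ.

```python
from typing import List

NO_NODE = -1  # sentinel for "no child" / "no sibling"


def children_of(node: int, next_token: List[int], next_sibling: List[int]) -> List[int]:
    """Collect the children of `node` by following first-child / next-sibling links."""
    children = []
    child = next_token[node]         # O(1) lookup of the first child
    while child != NO_NODE:
        children.append(child)
        child = next_sibling[child]  # O(1) hop to the next sibling
    return children


# Hypothetical 5-node tree for one sample: 0 -> {1, 2}, 1 -> {3, 4}
next_token = [1, 3, -1, -1, -1]
next_sibling = [-1, 2, -1, 4, -1]

print(children_of(0, next_token, next_sibling))  # [1, 2]
print(children_of(1, next_token, next_sibling))  # [3, 4]
print(children_of(2, next_token, next_sibling))  # []
```

Each individual hop is a constant-time array lookup; enumerating all children of a node costs time proportional to the number of children.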


## Interface Prototype

### Python Binding Definition
```python
import sgl_kernel_npu

torch.ops.npu.build_tree_kernel_efficient(
    parent_list: torch.Tensor,           # int64, [batch_size, topk*(depth-1)+1]
    selected_index: torch.Tensor,        # int64, [batch_size, draft_token_num-1]
    verified_seq_len: torch.Tensor,      # int64, [batch_size]
    tree_mask: torch.Tensor,             # bool, flat: [draft_token_num*(seq_len[0]+draft_token_num) | draft_token_num*(seq_len[1]+draft_token_num) | ...],
                                         # i.e. sum(verified_seq_len)*draft_token_num + batch_size*draft_token_num*draft_token_num elements
    positions: torch.Tensor,             # int64, [batch_size, draft_token_num]
    retrive_index: torch.Tensor,         # int64, [batch_size, draft_token_num]
    retrive_next_token: torch.Tensor,    # int64, [batch_size, draft_token_num]
    retrive_next_sibling: torch.Tensor,  # int64, [batch_size, draft_token_num]
    topk: int,
    depth: int,
    draft_token_num: int,
    tree_mask_mode: int
) -> None
```
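
The `tree_mask` size in the prototype can be checked with a little arithmetic. The sketch below assumes the shape comment means that each sample owns `draft_token_num` rows of `seq_len + draft_token_num` entries, stored back to back in one flat buffer; the helper name `full_mask_layout` is hypothetical and not part of the binding.

```python
from typing import List, Tuple


def full_mask_layout(verified_seq_len: List[int], draft_token_num: int) -> Tuple[int, List[int]]:
    """Return (total element count, per-sample start offsets) of the flat mask buffer."""
    offsets = []
    total = 0
    for seq_len in verified_seq_len:
        offsets.append(total)
        # One sample: draft_token_num rows, each seq_len + draft_token_num wide.
        total += draft_token_num * (seq_len + draft_token_num)
    return total, offsets


total, offsets = full_mask_layout([5, 3], draft_token_num=8)
print(total, offsets)  # 192 [0, 104]

# Same total via the closed form from the shape comment:
# sum(verified_seq_len)*draft_token_num + batch_size*draft_token_num^2
print(sum([5, 3]) * 8 + 2 * 8 * 8)  # 192
```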

### Kernel Definition
```C++
extern "C" __global__ __aicore__ void build_tree_efficient(GM_ADDR parent_list,
                                                           GM_ADDR selected_index,
                                                           GM_ADDR verified_seq_len,
                                                           GM_ADDR tree_mask,
                                                           GM_ADDR positions,
                                                           GM_ADDR retrive_index,
                                                           GM_ADDR retrive_next_token,
                                                           GM_ADDR retrive_next_sibling,
                                                           GM_ADDR workspace_in,
                                                           GM_ADDR tiling_in)
```

## Parameter Description

| Parameter Name     | Data Type              | Description                                                                                  |
|:-------------------|:-----------------------|:---------------------------------------------------------------------------------------------|
| `parent_list`      | `torch.Tensor` (int64) | Parent id of every draft token; shape `[batch_size, topk*(depth-1)+1]`                       |
| `selected_index`   | `torch.Tensor` (int64) | Flat index of each sampled token in the top-k list; shape `[batch_size, draft_token_num-1]`  |
| `verified_seq_len` | `torch.Tensor` (int64) | Length of the already-verified prefix of each sample; shape `[batch_size]`                   |
| `topk`             | `int`                  | Branching factor per step                                                                    |
| `depth`            | `int`                  | Maximum speculative depth                                                                    |
| `draft_token_num`  | `int`                  | Total number of draft tokens per sample                                                      |
| `tree_mask_mode`   | `int`                  | Mask layout mode (1 = `FULL_MASK`)                                                           |


## Output Description

| Parameter Name         | Data Type              | Description                                         |
|:-----------------------|:-----------------------|:----------------------------------------------------|
| `tree_mask`            | `torch.Tensor` (bool)  | `True` means the node must be verified              |
| `positions`            | `torch.Tensor` (int64) | Absolute position of each node in the full sequence |
| `retrive_index`        | `torch.Tensor` (int64) | Maps each node to its flat index for quick lookup   |
| `retrive_next_token`   | `torch.Tensor` (int64) | Id of the first child (-1 = none)                   |
| `retrive_next_sibling` | `torch.Tensor` (int64) | Id of the next sibling (-1 = none)                  |
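
As a rough illustration of what "absolute position in the full sequence" means for a tree node, the sketch below assumes that a node's position equals the verified prefix length plus the node's depth in the tree, so the root sits directly after the verified prefix. This is an assumption for illustration, not a statement about the kernel's internals. The `parent` array is a hypothetical direct-parent encoding (with -1 for the root), which is simpler than the kernel's `parent_list` / `selected_index` inputs.

```python
from typing import List


def node_positions(parent: List[int], verified_len: int) -> List[int]:
    """Position of each node = verified prefix length + depth of the node."""
    positions = []
    for node in range(len(parent)):
        depth = 0
        cur = node
        while parent[cur] != -1:  # walk up until the root is reached
            cur = parent[cur]
            depth += 1
        positions.append(verified_len + depth)
    return positions


# Hypothetical tree: 0 is the root, 1 and 2 are its children, 3 is a child of 1.
print(node_positions([-1, 0, 0, 1], verified_len=10))  # [10, 11, 11, 12]
```

Siblings share a position because they are alternative candidates for the same slot in the sequence.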


## Constraints

`TreeMaskMode.QLEN_ONLY_BITPACKING = 2` is not implemented.
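
Callers may want to reject the unsupported mode before dispatching to the kernel. The following is a minimal caller-side sketch; the helper `check_tree_mask_mode` is hypothetical and not part of `sgl_kernel_npu`, and the only enum value it relies on is the `2` stated above.

```python
QLEN_ONLY_BITPACKING = 2  # value stated in the constraint above


def check_tree_mask_mode(tree_mask_mode: int) -> int:
    """Raise early if the caller requests the unimplemented bit-packed mask mode."""
    mode = int(tree_mask_mode)
    if mode == QLEN_ONLY_BITPACKING:
        raise NotImplementedError(
            "TreeMaskMode.QLEN_ONLY_BITPACKING (2) is not implemented by the AscendC kernel"
        )
    return mode
```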

## Example

```python
import math

import sgl_kernel_npu
import torch
import torch_npu

# TreeMaskMode is the mask-layout enum used by the caller (FULL_MASK, QLEN_ONLY,
# QLEN_ONLY_BITPACKING). It is not defined in this snippet and must be imported
# from wherever your code base defines it.

device = torch.device("npu:0")

topk = 4
depth = 4
num_verify_tokens = 8

# Inputs produced by the draft model (elided here).
parent_list = ...
top_scores_index = ...
seq_lens = ...

bs = seq_lens.numel()
# Total verified length across the batch, needed for the FULL_MASK buffer size.
seq_lens_sum = int(seq_lens.sum().item())

tree_mask_mode = TreeMaskMode.FULL_MASK
if tree_mask_mode == TreeMaskMode.QLEN_ONLY:
    tree_mask = torch.full(
        (num_verify_tokens * bs * num_verify_tokens,),
        True,
        dtype=torch.bool,
        device=device,
    )
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
    # Note: this mode is not implemented by the AscendC kernel (see Constraints).
    packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
    packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
    tree_mask = torch.zeros(
        (num_verify_tokens * bs,),
        dtype=packed_dtypes[packed_dtype_idx],
        device=device,
    )
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
    tree_mask = torch.full(
        (
            seq_lens_sum * num_verify_tokens
            + num_verify_tokens * num_verify_tokens * bs,
        ),
        True,
        dtype=torch.bool,
        device=device,
    )
else:
    raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")

# Linked-list outputs start at -1 ("none") and are filled in by the kernel.
retrive_index = torch.full(
    (bs, num_verify_tokens), -1, device=device, dtype=torch.long
)
retrive_next_token = torch.full(
    (bs, num_verify_tokens), -1, device=device, dtype=torch.long
)
retrive_next_sibling = torch.full(
    (bs, num_verify_tokens), -1, device=device, dtype=torch.long
)

positions = torch.empty(
    (bs * num_verify_tokens,), device=device, dtype=torch.long
)

# The kernel writes its results into the preallocated output tensors in place.
torch.ops.npu.build_tree_kernel_efficient(
    parent_list,
    top_scores_index,
    seq_lens,
    tree_mask,
    positions,
    retrive_index,
    retrive_next_token,
    retrive_next_sibling,
    topk,
    depth,
    num_verify_tokens,
    tree_mask_mode,
)
```
