|
| 1 | +# torch.ops.build_tree_kernel_efficient |
| 2 | + |
| 3 | + |
| 4 | +## Function Description | 功能描述 |
| 5 | + |
| 6 | +### English: |
| 7 | +This is the AscendC version `build_tree_kernel_efficient` kernel function, which organizes the draft model’s multi-step top-k candidate tokens into a verification tree。 |
| 8 | + |
| 9 | +Adapted from [CUDA Implementation](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/speculative/eagle_utils.cu) |
| 10 | + |
| 11 | +For each sample it concurrently builds |
| 12 | + |
| 13 | +- tree_mask – which nodes must be verified by the target model |
| 14 | +- positions – absolute position of each node in the full sequence |
| 15 | +- retrive_* linked lists – allow O(1) navigation to children & siblings |
| 16 | + |
| 17 | +### 中文: |
| 18 | +这是AscendC版本的`build_tree_kernel_efficient`内核方法,它将draft模型产生的多步top-k候选token组织成验证树(verification tree) |
| 19 | + |
| 20 | +引用自 [CUDA 实现](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/speculative/eagle_utils.cu) |
| 21 | + |
| 22 | +内核为每个样本并行构造 |
| 23 | + |
| 24 | +- 树掩码 tree_mask(标记哪些节点需要被大模型验证) |
| 25 | +- 位置编码 positions(节点在完整序列中的位置) |
| 26 | +- 检索链表 retrive_*(支持 O(1) 找到子节点与兄弟节点) |
| 27 | + |
| 28 | + |
| 29 | +## Interface Prototype | 接口原型 |
| 30 | + |
| 31 | +### Python Binding Definition |
| 32 | +```python |
| 33 | +import sgl_kernel_npu |
| 34 | + |
| 35 | +torch.ops.npu.build_tree_kernel_efficient( |
| 36 | + parent_list: torch.Tensor, # int64, [batch_size, topk*(depth-1)+1] |
| 37 | + selected_index: torch.Tensor, # int64, [batch_size, draft_token_num-1] |
| 38 | + verified_seq_len: torch.Tensor, # int64, [batch_size] |
| 39 | + tree_mask: torch.Tensor, # bool, [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = |
| 40 | + # [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] |
| 41 | + positions: torch.Tensor, # int64, [batch_size, draft_token_num] |
| 42 | + retrive_index: torch.Tensor, # int64, [batch_size, draft_token_num] |
| 43 | + retrive_next_token: torch.Tensor, # int64, [batch_size, draft_token_num] |
| 44 | + retrive_next_sibling: torch.Tensor, # int64, [batch_size, draft_token_num] |
| 45 | + topk: int, |
| 46 | + depth: int, |
| 47 | + draft_token_num: int, |
| 48 | + tree_mask_mode: int |
| 49 | +) -> None |
| 50 | +``` |
| 51 | + |
| 52 | +### Kernel Definition | 核函数定义 |
| 53 | +```C++ |
| 54 | +extern "C" __global__ __aicore__ void build_tree_efficient(GM_ADDR parent_list, |
| 55 | + GM_ADDR selected_index, |
| 56 | + GM_ADDR verified_seq_len, |
| 57 | + GM_ADDR tree_mask, |
| 58 | + GM_ADDR positions, |
| 59 | + GM_ADDR retrive_index, |
| 60 | + GM_ADDR retrive_next_token, |
| 61 | + GM_ADDR retrive_next_sibling, |
| 62 | + GM_ADDR workspace_in, |
| 63 | + GM_ADDR tiling_in) |
| 64 | +``` |
| 65 | +
|
| 66 | +## Parameter Description | 参数说明 |
| 67 | +
|
| 68 | +| Parameter Name (参数名称) | DataType (数据类型) | Description | 说明 | |
| 69 | +|:----------------------|:----------------|:------------------------------------------|:---------------------------| |
| 70 | +| `parent_list` | `torch.Tensor` | parent id of every draft token | 每个 draft token 的父节点 id | |
| 71 | +| `selected_index` | `torch.Tensor` | flat index of sampled token in top-k list | 采样的 token 在 top-k 列表中的扁平索引 | |
| 72 | +| `verified_seq_len` | `torch.Tensor` | length of already-verified prefix | 当前已验证序列长度 | |
| 73 | +| `topk` | `int` | branching factor per step | 每步分支数 | |
| 74 | +| `depth` | `int` | maximum speculative depth | 最大投机深度 | |
| 75 | +| `draft_token_num` | `int` | total #draft tokens per sample | 单样本 draft token 总数 | |
| 76 | +| `tree_mask_mode` | `int` | mask layout mode (1=FULL\_MASK) | 掩码布局模式(1=FULL\_MASK) | |
| 77 | +
|
| 78 | +
|
| 79 | +## Output Description | 输出说明 |
| 80 | +
|
| 81 | +| Parameter Name (参数名称) | DataType (数据类型) | Description | 说明 | |
| 82 | +|:-----------------------|:----------------|:-----------------------------------|:-------------------| |
| 83 | +| `tree_mask` | `torch.Tensor` | true → node must be verified | true → 该节点需被验证 | |
| 84 | +| `positions` | `torch.Tensor` | absolute position in full sequence | 节点在完整序列中的绝对位置 | |
| 85 | +| `retrive_index` | `torch.Tensor` | node → flat index for quick lookup | 快速检索:节点→扁平索引 | |
| 86 | +| `retrive_next_token` | `torch.Tensor` | first child id (-1 = none) | 第一个子节点 id(-1 表示无) | |
| 87 | +| `retrive_next_sibling` | `torch.Tensor` | next sibling id (-1 = none) | 下一个兄弟节点 id(-1 表示无) | |
| 88 | +
|
| 89 | +
|
| 90 | +## Constraints | 约束说明 |
| 91 | +
|
| 92 | +### English: |
| 93 | +`TreeMaskMode.QLEN_ONLY_BITPACKING = 2` is not implemented |
| 94 | +
|
| 95 | +### 中文: |
| 96 | +`TreeMaskMode.QLEN_ONLY_BITPACKING = 2` 暂未实现 |
| 97 | +
|
| 98 | +## Example | 调用示例 |
| 99 | +
|
| 100 | +```python |
| 101 | +import math |
| 102 | +import sgl_kernel_npu |
| 103 | +import torch |
| 104 | +import torch_npu |
| 105 | +
|
| 106 | +device = torch.device('npu:0') |
| 107 | +
|
| 108 | +topk = 4 |
| 109 | +depth = 4 |
| 110 | +num_verify_tokens = 8 |
| 111 | +
|
| 112 | +parent_list=... |
| 113 | +top_scores_index=... |
| 114 | +seq_lens=... |
| 115 | +
|
| 116 | +bs = seq_lens.numel() |
| 117 | +
|
| 118 | +tree_mask_mode = TreeMaskMode.FULL_MASK |
| 119 | +if tree_mask_mode == TreeMaskMode.QLEN_ONLY: |
| 120 | + tree_mask = torch.full( |
| 121 | + (num_verify_tokens * bs * num_verify_tokens,), |
| 122 | + True, |
| 123 | + dtype=torch.bool, |
| 124 | + device=device, |
| 125 | + ) |
| 126 | +elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING: |
| 127 | + packed_dtypes = [torch.uint8, torch.uint16, torch.uint32] |
| 128 | + packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8))) |
| 129 | + tree_mask = torch.zeros( |
| 130 | + (num_verify_tokens * bs,), |
| 131 | + dtype=packed_dtypes[packed_dtype_idx], |
| 132 | + device=device, |
| 133 | + ) |
| 134 | +elif tree_mask_mode == TreeMaskMode.FULL_MASK: |
| 135 | + tree_mask = torch.full( |
| 136 | + ( |
| 137 | + seq_lens_sum * num_verify_tokens |
| 138 | + + num_verify_tokens * num_verify_tokens * bs, |
| 139 | + ), |
| 140 | + True, |
| 141 | + device=device, |
| 142 | + ) |
| 143 | +else: |
| 144 | + raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}") |
| 145 | +
|
| 146 | +retrive_index = torch.full( |
| 147 | + (bs, num_verify_tokens), -1, device=device, dtype=torch.long |
| 148 | +) |
| 149 | +retrive_next_token = torch.full( |
| 150 | + (bs, num_verify_tokens), -1, device=device, dtype=torch.long |
| 151 | +) |
| 152 | +retrive_next_sibling = torch.full( |
| 153 | + (bs, num_verify_tokens), -1, device=device, dtype=torch.long |
| 154 | +) |
| 155 | +
|
| 156 | +positions = torch.empty( |
| 157 | + (bs * num_verify_tokens,), device=device, dtype=torch.long |
| 158 | +) |
| 159 | +
|
| 160 | +torch.ops.npu.build_tree_kernel_efficient( |
| 161 | + parent_list, |
| 162 | + top_scores_index, |
| 163 | + seq_lens, |
| 164 | + tree_mask, |
| 165 | + positions, |
| 166 | + retrive_index, |
| 167 | + retrive_next_token, |
| 168 | + retrive_next_sibling, |
| 169 | + topk, |
| 170 | + depth, |
| 171 | + num_verify_tokens, |
| 172 | + tree_mask_mode, |
| 173 | +) |
| 174 | +``` |
0 commit comments