Skip to content

Commit 5e748e8

Browse files
hw-csongsongchenyi
authored andcommitted
add test cases for build_tree_efficient ascendc ops
1 parent 30b8eef commit 5e748e8

File tree

1 file changed

+38
-19
lines changed

1 file changed

+38
-19
lines changed

tests/python/sgl_kernel_npu/test_build_tree.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -123,25 +123,44 @@ def build_tree_kernel_efficient(
123123
positions = torch.empty(
124124
(bs * num_verify_tokens,), device=device, dtype=torch.long
125125
)
126-
(
127-
positions,
128-
retrive_index,
129-
retrive_next_token,
130-
retrive_next_sibling,
131-
tree_mask,
132-
) = build_tree_efficient_native(
133-
parent_list,
134-
top_scores_index,
135-
seq_lens,
136-
tree_mask,
137-
retrive_index,
138-
retrive_next_token,
139-
retrive_next_sibling,
140-
topk,
141-
num_verify_tokens,
142-
tree_mask_mode,
143-
bs,
144-
)
126+
127+
try:
128+
import sgl_kernel_npu
129+
130+
torch.ops.npu.build_tree_kernel_efficient(
131+
parent_list,
132+
top_scores_index,
133+
seq_lens,
134+
tree_mask,
135+
positions,
136+
retrive_index,
137+
retrive_next_token,
138+
retrive_next_sibling,
139+
topk,
140+
spec_steps,
141+
num_verify_tokens,
142+
tree_mask_mode,
143+
)
144+
except ImportError:
145+
(
146+
positions,
147+
retrive_index,
148+
retrive_next_token,
149+
retrive_next_sibling,
150+
tree_mask,
151+
) = build_tree_efficient_native(
152+
parent_list,
153+
top_scores_index,
154+
seq_lens,
155+
tree_mask,
156+
retrive_index,
157+
retrive_next_token,
158+
retrive_next_sibling,
159+
topk,
160+
num_verify_tokens,
161+
tree_mask_mode,
162+
bs,
163+
)
145164

146165
return (
147166
tree_mask,

0 commit comments

Comments
 (0)