File tree Expand file tree Collapse file tree 1 file changed +38
-19
lines changed
tests/python/sgl_kernel_npu Expand file tree Collapse file tree 1 file changed +38
-19
lines changed Original file line number Diff line number Diff line change @@ -123,25 +123,44 @@ def build_tree_kernel_efficient(
123123 positions = torch .empty (
124124 (bs * num_verify_tokens ,), device = device , dtype = torch .long
125125 )
126- (
127- positions ,
128- retrive_index ,
129- retrive_next_token ,
130- retrive_next_sibling ,
131- tree_mask ,
132- ) = build_tree_efficient_native (
133- parent_list ,
134- top_scores_index ,
135- seq_lens ,
136- tree_mask ,
137- retrive_index ,
138- retrive_next_token ,
139- retrive_next_sibling ,
140- topk ,
141- num_verify_tokens ,
142- tree_mask_mode ,
143- bs ,
144- )
126+
127+ try :
128+ import sgl_kernel_npu
129+
130+ torch .ops .npu .build_tree_kernel_efficient (
131+ parent_list ,
132+ top_scores_index ,
133+ seq_lens ,
134+ tree_mask ,
135+ positions ,
136+ retrive_index ,
137+ retrive_next_token ,
138+ retrive_next_sibling ,
139+ topk ,
140+ spec_steps ,
141+ num_verify_tokens ,
142+ tree_mask_mode ,
143+ )
144+ except ImportError :
145+ (
146+ positions ,
147+ retrive_index ,
148+ retrive_next_token ,
149+ retrive_next_sibling ,
150+ tree_mask ,
151+ ) = build_tree_efficient_native (
152+ parent_list ,
153+ top_scores_index ,
154+ seq_lens ,
155+ tree_mask ,
156+ retrive_index ,
157+ retrive_next_token ,
158+ retrive_next_sibling ,
159+ topk ,
160+ num_verify_tokens ,
161+ tree_mask_mode ,
162+ bs ,
163+ )
145164
146165 return (
147166 tree_mask ,
You can’t perform that action at this time.
0 commit comments