@@ -141,7 +141,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
     TP2, rank 1:
         |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
     corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1  | ... | -1  | -1  | ... | -1  | -1  | ... | -1  |
-                     index: |  0  |  1  |  2  | ... | 497  | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
+                     index: |  0  |  1  |  2  | ... | 497  | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
     Parameters:
         org_vocab_start_index  // base embeddings start
         org_vocab_end_index    // base embeddings end
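The layout above boils down to simple per-element arithmetic, sketched below in plain C++ for a single token id. This mirrors the documented semantics (and vLLM's Python reference for vocab-parallel embedding with LoRA), not the AscendC kernel itself; parameter names follow the list above:

    #include <cstdint>
    #include <utility>

    // Returns {index into the local shard, mask}; mask == true means the token
    // falls outside this rank's ranges and its embedding is zeroed afterwards.
    std::pair<int64_t, bool> mask_token(int64_t token_id,
                                        int64_t org_vocab_start_index,
                                        int64_t org_vocab_end_index,
                                        int64_t num_org_vocab_padding,
                                        int64_t added_vocab_start_index,
                                        int64_t added_vocab_end_index) {
        const bool in_org = token_id >= org_vocab_start_index &&
                            token_id < org_vocab_end_index;
        const bool in_added = token_id >= added_vocab_start_index &&
                              token_id < added_vocab_end_index;
        // Added (LoRA) ids land after the base shard plus its padding.
        const int64_t added_offset = added_vocab_start_index -
            (org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding;
        int64_t index = 0;  // out-of-range tokens map to 0 and are masked
        if (in_org) {
            index = token_id - org_vocab_start_index;
        } else if (in_added) {
            index = token_id - added_offset;
        }
        return {index, !(in_org || in_added)};
    }

On the TP2 rank 1 layout above (org_vocab_start_index = 512), token 512 maps to index 0 and token 1009 to index 497, matching the index row.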
@@ -164,22 +164,22 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
    // Create output tensors
    at::Tensor masked_input = at::empty_like(input);
    at::Tensor mask = at::empty_like(input).to(at::kBool);
-
+
    // Get data pointers
    void *input_ptr = input.data_ptr();
    void *masked_input_ptr = masked_input.data_ptr();
    void *mask_ptr = mask.data_ptr();
-
+
    // Get current stream
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
-
+
    // Get scalar type
    at::ScalarType scalar_type = input.scalar_type();
-
+
    // Create and configure OpCommand
    at_npu::native::OpCommand cmd;
    cmd.Name("get_masked_input_and_mask");
-    cmd.SetCustomHandler([scalar_type, size, stream,
+    cmd.SetCustomHandler([scalar_type, size, stream,
                          input_ptr, masked_input_ptr, mask_ptr,
                          org_vocab_start_index, org_vocab_end_index,
                          num_org_vocab_padding, added_vocab_start_index,
@@ -193,7 +193,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
        get_masked_input_and_mask_impl(
            stream,
            input_ptr,
-            masked_input_ptr,
+            masked_input_ptr,
            mask_ptr,
            org_vocab_start_index,
            org_vocab_end_index,
@@ -203,7 +203,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
            size,
            loop_cnt,
            aiv_num);
-
+
        return 0;
    });
    cmd.Run();
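All three custom ops in this file follow the same dispatch shape: capture everything the kernel needs by value, hand `OpCommand` a callback that launches the AscendC kernel on the captured stream, then `Run()`. A condensed sketch of that pattern (hypothetical op name and impl signature; include paths assumed from torch_npu and ACL):

    #include <acl/acl.h>
    #include "torch_npu/csrc/framework/OpCommand.h"

    // Hypothetical kernel entry point, declared only to illustrate the pattern.
    extern void my_op_impl(aclrtStream stream, void *in, void *out, int64_t n);

    void launch_my_op(aclrtStream stream, void *in, void *out, int64_t n) {
        at_npu::native::OpCommand cmd;
        cmd.Name("my_op");
        // Capture by value: the handler can run after this function returns.
        cmd.SetCustomHandler([stream, in, out, n]() -> int {
            my_op_impl(stream, in, out, n);  // enqueue the kernel on the stream
            return 0;                        // 0 reports success to OpCommand
        });
        cmd.Run();
    }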
@@ -320,8 +320,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("sgmv_shrink");
-    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
-                         seq_len_ptr, seq_len_size, y_ptr,
+    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
+                         seq_len_ptr, seq_len_size, y_ptr,
                         batch_size, input_hidden_token, lora_rank, scale_f]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
@@ -330,7 +330,7 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size,
-                        y_ptr, batch_size,
+                        y_ptr, batch_size,
                        num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
        return 0;
    });
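What `sgmv_shrink` computes, stripped of tiling and multi-core scheduling, is the "shrink" half of a LoRA matmul: each token's hidden vector is contracted against the LoRA-A matrix selected by its lora index, scaled by `scale_f`. A scalar reference sketch under assumed layouts (row-major buffers, a per-token index array with -1 meaning "no LoRA"; the real kernel takes per-request indices plus seq_len):

    #include <cstdint>

    // y[t, r] = scale * sum_h x[t, h] * A[lora(t), r, h]
    void sgmv_shrink_ref(const float *x, const float *weight_a,
                         const int32_t *lora_indices, float *y,
                         int64_t num_tokens, int64_t hidden, int64_t rank,
                         float scale) {
        for (int64_t t = 0; t < num_tokens; ++t) {
            const int32_t lora = lora_indices[t];
            for (int64_t r = 0; r < rank; ++r) {
                float acc = 0.0f;
                if (lora >= 0) {
                    const float *a = weight_a + ((int64_t)lora * rank + r) * hidden;
                    for (int64_t h = 0; h < hidden; ++h)
                        acc += x[t * hidden + h] * a[h];
                }
                y[t * rank + r] = scale * acc;  // no-LoRA tokens get zeros
            }
        }
    }

Splitting LoRA into shrink and expand keeps the intermediate only `lora_rank` wide, which is why the two kernels exist as a pair.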
@@ -367,15 +367,15 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("sgmv_expand");
-    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
+    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
                         batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
-        sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
+        sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
                         batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
        return 0;
    });
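`sgmv_expand` is the matching "expand" half: the rank-sized intermediate is multiplied by the selected LoRA-B matrix and added into the slice `[slice_offset, slice_offset + slice_size)` of each `output_full_dim`-wide output row, which is how fused projections (e.g. merged QKV) receive their per-slice deltas. A reference sketch under the same assumptions as `sgmv_shrink_ref` above:

    #include <cstdint>

    // y_out[t, slice_offset + s] += sum_r x[t, r] * B[lora(t), s, r]
    void sgmv_expand_ref(const float *x, const float *weight_b,
                         const int32_t *lora_indices, float *y_out,
                         int64_t num_tokens, int64_t rank,
                         int64_t slice_offset, int64_t slice_size,
                         int64_t output_full_dim) {
        for (int64_t t = 0; t < num_tokens; ++t) {
            const int32_t lora = lora_indices[t];
            if (lora < 0) continue;  // token without LoRA: slice left unchanged
            for (int64_t s = 0; s < slice_size; ++s) {
                const float *b = weight_b + ((int64_t)lora * slice_size + s) * rank;
                float acc = 0.0f;
                for (int64_t r = 0; r < rank; ++r)
                    acc += x[t * rank + r] * b[r];
                y_out[t * output_full_dim + slice_offset + s] += acc;
            }
        }
    }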
@@ -384,7 +384,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
}
} // namespace vllm_ascend

-TORCH_LIBRARY_EXPAND(_C, ops)
+TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
{
    // vLLM-Ascend custom ops
    ops.def("weak_ref_tensor(Tensor input) -> Tensor");
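The `TORCH_LIBRARY_EXPAND`/`CONCAT` pair exists because `TORCH_LIBRARY` pastes its first argument into registration symbols with `##`, which suppresses macro expansion; routing the name through an extra macro layer makes the preprocessor turn `CONCAT(_C, _ascend)` into `_C_ascend` before the paste happens. A sketch of the usual definitions (assumption: the actual macros in this repo may differ in detail):

    #define CONCAT_IMPL(A, B) A##B
    #define CONCAT(A, B) CONCAT_IMPL(A, B)
    // Indirection so macro arguments are fully expanded before TORCH_LIBRARY
    // stringizes/pastes the library name into its registration machinery.
    #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)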