Commit ce8df08

Fix operator registration bugs with the PyTorch Dispatcher

**Background:** PyTorch enforces two rules for operator registration:

- A namespace can be registered only once, via `TORCH_LIBRARY`.
- An operator signature can be registered only once, via `def`.

All custom operators defined in this repo are used only by Ascend; they are not a common operator schema defined by vLLM that every accelerator would then implement for its own hardware (an approach that would favor modular functional abstraction). We can therefore rename the operator registration namespace to an Ascend-specific one, avoiding a collision with the `_C` namespace that vLLM itself registers.

Signed-off-by: FFFrog <ljw1101.vip@gmail.com>
1 parent: 6d8bc38

16 files changed, 96 insertions(+), 64 deletions(-)
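To make the first rule concrete, here is a minimal sketch (illustrative only, not code from this commit) of why an out-of-tree extension must not reuse vLLM's `_C` namespace, and what the Ascend-specific registration amounts to:

#include <torch/library.h>

// vLLM's own binding library already claims the `_C` namespace:
//
//     TORCH_LIBRARY(_C, ops) { ... }
//
// If vllm-ascend registered `_C` as well, loading both extensions would
// abort with an error along the lines of "Only a single TORCH_LIBRARY can
// be used to register the namespace _C". A distinct namespace is safe:
TORCH_LIBRARY(_C_ascend, ops) {
  // Schema-only `def`; the NPU implementation is registered separately.
  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
}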

benchmarks/ops/ben_vocabparallelembedding.py (1 addition, 1 deletion)

@@ -112,7 +112,7 @@ def ref_fn():
 
     # Define custom function
     def custom_fn():
-        return torch.ops._C.get_masked_input_and_mask(
+        return torch.ops._C_ascend.get_masked_input_and_mask(
             input_tensor,
             test_case["org_start"],
             test_case["org_end"],

csrc/torch_binding.cpp (14 additions, 14 deletions)

@@ -141,7 +141,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
     TP2, rank 1:
     |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
     corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
-    index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
+    index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
     Parameters:
     org_vocab_start_index //base embeddings start
     org_vocab_end_index //base embeddings end

@@ -164,22 +164,22 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
     // Create output tensors
     at::Tensor masked_input = at::empty_like(input);
     at::Tensor mask = at::empty_like(input).to(at::kBool);
-
+
     // Get data pointers
     void *input_ptr = input.data_ptr();
     void *masked_input_ptr = masked_input.data_ptr();
     void *mask_ptr = mask.data_ptr();
-
+
     // Get current stream
     aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
-
+
     // Get scalar type
     at::ScalarType scalar_type = input.scalar_type();
-
+
     // Create and configure OpCommand
     at_npu::native::OpCommand cmd;
     cmd.Name("get_masked_input_and_mask");
-    cmd.SetCustomHandler([scalar_type, size, stream,
+    cmd.SetCustomHandler([scalar_type, size, stream,
                           input_ptr, masked_input_ptr, mask_ptr,
                           org_vocab_start_index, org_vocab_end_index,
                           num_org_vocab_padding, added_vocab_start_index,

@@ -193,7 +193,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
         get_masked_input_and_mask_impl(
             stream,
             input_ptr,
-            masked_input_ptr,
+            masked_input_ptr,
             mask_ptr,
             org_vocab_start_index,
             org_vocab_end_index,

@@ -203,7 +203,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
             size,
             loop_cnt,
             aiv_num);
-
+
         return 0;
     });
     cmd.Run();

@@ -320,8 +320,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
     aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
     at_npu::native::OpCommand cmd;
     cmd.Name("sgmv_shrink");
-    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
-                          seq_len_ptr, seq_len_size, y_ptr,
+    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
+                          seq_len_ptr, seq_len_size, y_ptr,
                           batch_size, input_hidden_token, lora_rank, scale_f]() -> int {
         auto dtype = get_dtype_from_torch(scalar_type);
         int device_id = 0;

@@ -330,7 +330,7 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
         int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
         TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
         sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size,
-                         y_ptr, batch_size,
+                         y_ptr, batch_size,
                          num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
         return 0;
     });

@@ -367,15 +367,15 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
     aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
     at_npu::native::OpCommand cmd;
     cmd.Name("sgmv_expand");
-    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
+    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
                           batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
         auto dtype = get_dtype_from_torch(scalar_type);
         int device_id = 0;
         int64_t aiv_num = 0;
         TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
         int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
         TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
-        sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
+        sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
                          batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
         return 0;
     });

@@ -384,7 +384,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
 }
 } // namespace vllm_ascend

-TORCH_LIBRARY_EXPAND(_C, ops)
+TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
 {
     // vLLM-Ascend custom ops
     ops.def("weak_ref_tensor(Tensor input) -> Tensor");

csrc/torch_binding_meta.cpp (4 additions, 4 deletions)

@@ -40,7 +40,7 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
     at::Tensor &positions,
     at::Tensor &query,
     at::Tensor &key,
-    int64_t head_size,
+    int64_t head_size,
     at::Tensor &cos_sin_cache,
     bool is_neox) {
     auto num_tokens = positions.sym_numel();

@@ -86,9 +86,9 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
 } // namespace vllm_ascend

 namespace {
-// Register the meta implementations of the custom kernels for symbolic tracing, this will also
+// Register the meta implementations of the custom kernels for symbolic tracing, this will also
 // the custom kernel been captured into aclgraph
-TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
+TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
     // Rotary embedding meta implementation
     ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
     // Masked input and mask meta implementation

@@ -99,4 +99,4 @@ namespace {
     ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);

 }
-}
+}
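
The Meta-key registrations above are what let these custom kernels be traced symbolically (and hence captured into aclgraph): a meta kernel computes only output shapes and dtypes, never touching device memory. As a sketch, a hypothetical simplified `get_masked_input_and_mask_meta` consistent with the real kernel's outputs (a masked copy of `input` plus a boolean mask, as seen in torch_binding.cpp) might look like:

#include <ATen/ATen.h>
#include <tuple>

// Hypothetical simplified meta kernel. On the Meta dispatch key, tensors
// carry shape/dtype metadata but no storage, so allocating the outputs is
// all that is needed.
std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
    at::Tensor &input,
    int64_t org_vocab_start_index, int64_t org_vocab_end_index,
    int64_t num_org_vocab_padding,
    int64_t added_vocab_start_index, int64_t added_vocab_end_index) {
  at::Tensor masked_input = at::empty_like(input);
  at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool));
  return {masked_input, mask};
}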

tests/e2e/singlecard/ops/test_bgmv_expand.py (2 additions, 2 deletions)

@@ -33,8 +33,8 @@ def test_bgmv_expand():
     y_npu = y.npu()

     y_out = bgmv_expand_cpu_impl(x, w, indices, y, 0, 128)
-    y_out_npu = torch.ops._C.bgmv_expand(x_npu, w_npu, indices_npu, y_npu, 0,
-                                         128)
+    y_out_npu = torch.ops._C_ascend.bgmv_expand(x_npu, w_npu, indices_npu,
+                                                y_npu, 0, 128)

     # Compare the results.
     torch.testing.assert_close(y_out_npu.cpu(),

tests/e2e/singlecard/ops/test_bgmv_shrink.py (1 addition, 1 deletion)

@@ -33,7 +33,7 @@ def test_bgmv_shrink():
     y_npu = y.npu()

     y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5)
-    torch.ops._C.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)
+    torch.ops._C_ascend.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)

     # Compare the results.
     torch.testing.assert_close(y_npu.cpu(),

tests/e2e/singlecard/ops/test_rotary_embedding.py (2 additions, 2 deletions)

@@ -182,7 +182,7 @@ def test_rotary_embedding_quant_with_leading_dim(
     )

     ref_query, ref_key = rope.forward_native(positions, query, key)
-    query, key = torch.ops._C.rotary_embedding(
+    query, key = torch.ops._C_ascend.rotary_embedding(
         positions,
         query,
         key,

@@ -239,7 +239,7 @@ def forward(
         # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
         qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(3, dim=-1)
-        query, key = torch.ops._C.rotary_embedding(
+        query, key = torch.ops._C_ascend.rotary_embedding(
             positions,
             q,
             k,

tests/e2e/singlecard/ops/test_vocabparallelembedding.py (1 addition, 1 deletion)

@@ -72,7 +72,7 @@ def test_get_masked_input_and_mask(

     # Get custom op result
     print("input_tensor:", input_tensor)
-    custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask(
+    custom_masked_input, custom_mask = torch.ops._C_ascend.get_masked_input_and_mask(
         input_tensor, test_case["org_start"], test_case["org_end"],
         test_case["padding"], test_case["added_start"], test_case["added_end"])

tests/ut/ops/test_rotary_embedding.py (1 addition, 1 deletion)

@@ -94,7 +94,7 @@ def setUp(self):
         self.mock_self.cos_sin_cache = self.cos_sin_cache
         self.mock_self.is_neox_style = self.is_neox_style

-    @patch('torch.ops._C')
+    @patch('torch.ops._C_ascend')
     @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
     @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
            return_value=True)

tests/ut/torchair/ops/test_torchair_rotary_embedding.py (1 addition, 1 deletion)

@@ -104,7 +104,7 @@ def test_rope_forward_oot_torchair_enabled_base(self,
         self.assertTrue(torch.equal(result_q, self.query))
         self.assertTrue(torch.equal(result_k, self.key))

-    @patch('torch.ops._C')
+    @patch('torch.ops._C_ascend')
     @patch(
         'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
     @patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p',

vllm_ascend/compilation/acl_graph.py (7 additions, 6 deletions)

@@ -15,7 +15,8 @@
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.platforms import current_platform
-from vllm.utils import weak_ref_tensors
+
+from ..utils import weak_ref_tensors


 @dataclasses.dataclass

@@ -35,10 +36,10 @@ class ACLGraphWrapper:

     The workflow of this wrapper in the aclgraph dispatching is as follows:
     1. At initialization, a runtime mode is assigned to the wrapper (FULL or
-    PIECEWISE).
-    2. At runtime, the wrapper receives a runtime_mode and a
+    PIECEWISE).
+    2. At runtime, the wrapper receives a runtime_mode and a
     batch_descriptor(key) from the forward context and blindly trust them
-    for aclgraph dispatching.
+    for aclgraph dispatching.
     3. If runtime_mode is NONE or runtime_mode does not match the mode of the
     wrapper, just call the runnable directly.
     4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,

@@ -47,9 +48,9 @@ class ACLGraphWrapper:

     Note: ACLGraphWrapper does not store persistent buffers or copy any
     runtime inputs into that buffers for replay. We assume implementing them
-    is done outside of the wrapper. That is because we do not make any
+    is done outside of the wrapper. That is because we do not make any
     assumption on the dynamic shape (batch size) of the runtime inputs, as a
-    trade-off for staying orthogonal to compilation logic. Nevertheless,
+    trade-off for staying orthogonal to compilation logic. Nevertheless,
     tracing and checking the input addresses to be consistent during replay is
     guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
     """
